Skip to content

Commit 9a2257b

Browse files
authored
Merge pull request #332 from ROCm/genesu/fix-small-bugs
[Fix] Missed import and dim reference
2 parents 409b43a + 09f67f6 commit 9a2257b

File tree

2 files changed

+3
-1
lines changed

2 files changed

+3
-1
lines changed

setup.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,8 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
152152
if not found_cmake():
153153
setup_reqs.append("cmake>=3.21")
154154
if not found_ninja():
155+
import sys
156+
155157
subprocess.check_call([sys.executable, "-m", "pip", "install", "ninja"])
156158
setup_reqs.append("ninja")
157159
if not found_pybind11():

transformer_engine/jax/flax/transformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -637,7 +637,7 @@ def __call__(
637637
dropout_rng = self.make_rng(self.dropout_rng_name)
638638

639639
if self.scale_factor is None:
640-
scale_factor = 1.0 / sqrt(self.head_dim_qk)
640+
scale_factor = 1.0 / sqrt(head_dim_qk)
641641
else:
642642
scale_factor = self.scale_factor
643643
del self.scale_factor

0 commit comments

Comments
 (0)