Skip to content

Commit 7e7277b

Browse files
committed
[Fix] Missed import and dim reference
Signed-off-by: Gene Der Su <[email protected]>
1 parent 409b43a commit 7e7277b

File tree

2 files changed

+2
-1
lines changed

2 files changed

+2
-1
lines changed

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
"""Installation script."""
88

99
import os
10+
import sys
1011
import time
1112
from pathlib import Path
1213
from typing import List, Tuple

transformer_engine/jax/flax/transformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -637,7 +637,7 @@ def __call__(
637637
dropout_rng = self.make_rng(self.dropout_rng_name)
638638

639639
if self.scale_factor is None:
640-
scale_factor = 1.0 / sqrt(self.head_dim_qk)
640+
scale_factor = 1.0 / sqrt(head_dim_qk)
641641
else:
642642
scale_factor = self.scale_factor
643643
del self.scale_factor

0 commit comments

Comments
 (0)