Skip to content

Commit 8196de1

Browse files
committed
Skip unnecessary compilation
1 parent 8c1889e commit 8196de1

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

megatron/initialize.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,15 @@ def _compile_dependencies():
158158
print('>>> done with dataset index builder. Compilation time: {:.3f} '
159159
'seconds'.format(time.time() - start_time), flush=True)
160160

161+
try:
162+
# Skip the rest if the kernels are unnecessary or already available (ex. from apex)
163+
if args.use_flash_attn or args.masked_softmax_fusion:
164+
import scaled_upper_triang_masked_softmax_cuda
165+
import scaled_masked_softmax_cuda
166+
return
167+
except ImportError:
168+
pass
169+
161170
# ==================
162171
# Load fused kernels
163172
# ==================

0 commit comments

Comments
 (0)