Skip to content

Commit 4a33f29

Browse files
committed
fused kernel import
1 parent e8c74d5 commit 4a33f29

File tree

2 files changed

+2
-4
lines changed

2 files changed

+2
-4
lines changed

tools/checkpoint_loader_megatron.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
import torch
77

8-
from megatron import fused_kernels
98

109
def add_arguments(parser):
1110
group = parser.add_argument_group(title='Megatron loader')
@@ -34,7 +33,7 @@ def _load_checkpoint(queue, args):
3433
from megatron.global_vars import set_args, set_global_variables
3534
from megatron.checkpointing import load_args_from_checkpoint, load_checkpoint
3635
from megatron.model import ModelType, module
37-
from megatron import mpu
36+
from megatron import mpu, fused_kernels
3837
except ModuleNotFoundError:
3938
print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.")
4039
queue.put("exit")

tools/checkpoint_saver_megatron.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
import torch
88

9-
from megatron import fused_kernels
109

1110
def add_arguments(parser):
1211
group = parser.add_argument_group(title='Megatron saver')
@@ -38,7 +37,7 @@ def save_checkpoint(queue, args):
3837
from megatron.global_vars import set_global_variables, get_args
3938
from megatron.model import ModelType
4039
from megatron.tokenizer.tokenizer import _vocab_size_with_padding
41-
from megatron import mpu
40+
from megatron import mpu, fused_kernels
4241
except ModuleNotFoundError:
4342
print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.")
4443
exit(1)

0 commit comments

Comments
 (0)