File tree Expand file tree Collapse file tree 2 files changed +7
-2
lines changed
source_transformation/torchtune Expand file tree Collapse file tree 2 files changed +7
-2
lines changed Original file line number Diff line number Diff line change 7070 replace_sdpa_with_simple_sdpa ,
7171)
7272
73- from .source_transformation .vulkan_rope import replace_with_vulkan_rotary_emb
74-
7573from .source_transformation .torchtune .attention import replace_mha_with_inference_mha
7674
75+ from .source_transformation .vulkan_rope import replace_with_vulkan_rotary_emb
76+
7777
7878IS_FBCODE = True # os.environ.get("FBCODE_PLATFORM", False)
7979FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
@@ -1019,4 +1019,8 @@ def _get_source_transforms( # noqa
10191019 if args .vulkan :
10201020 transforms .append (replace_with_vulkan_rotary_emb )
10211021
1022+ print (
1023+ f"Performing the following source transformations: { [transform .__name__ for transform in transforms ]} "
1024+ )
1025+
10221026 return transforms
Original file line number Diff line number Diff line change @@ -32,6 +32,7 @@ def _replace_mha_with_inference_mha(module: torch.nn.Module) -> None:
3232 else :
3333 replace_mha_with_inference_mha (child )
3434
35+
3536def replace_mha_with_inference_mha (module : torch .nn .Module ) -> torch .nn .Module :
3637 """
3738 Replace TorchTune's MHA with an inference friendly version of MHA that
You can’t perform that action at this time.
0 commit comments