Skip to content

Commit 2f79364

Browse files
authored
enable powers of 2 cast in float8 rowwise_with_gw_hp recipe (#2677)
Summary: This should have been enabled from the time we added the powers of 2 scaling, fixing. Test Plan: ``` with-proxy CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --model.print_after_conversion --model.converters float8 --training.compile --float8.recipe_name rowwise_with_gw_hp ``` Reviewers: Subscribers: Tasks: Tags:
1 parent 5d99ce4 commit 2f79364

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

torchao/float8/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,7 @@ def from_recipe_name(
333333
cast_config_input_for_grad_weight=cc_i_gw,
334334
cast_config_weight_for_grad_input=cc_w_gi,
335335
cast_config_grad_output_for_grad_weight=cc_go_gw,
336+
round_scales_to_power_of_2=True,
336337
)
337338

338339
else:

0 commit comments

Comments
 (0)