File tree Expand file tree Collapse file tree 2 files changed +4
-2
lines changed
Expand file tree Collapse file tree 2 files changed +4
-2
lines changed Original file line number Diff line number Diff line change @@ -48,6 +48,8 @@ QAT of Qwen3-8B NVFP4 recovers most of the accuracy on the MMLU benchmark after
4848| Qwen3-8B NVFP4 | 70.3 |
4949| Qwen3-8B NVFP4 after QAT | 72.8 |
5050
51+ The resulting exported checkpoint also is much smaller in memory at 6.4GB compared to the original BF16 checkpoint which is 16.4 GB.
52+
5153## Usage
5254
5355### Prerequisites
Original file line number Diff line number Diff line change @@ -140,7 +140,7 @@ def get_args():
140140 action = "store_true" ,
141141 default = False ,
142142 )
143- parser .add_argument ("--tensor_parallelism" , type = int , default = 1 )
143+ parser .add_argument ("--tensor_parallelism" , type = int , default = 2 )
144144 parser .add_argument ("--pipeline_parallelism" , type = int , default = 1 )
145145 return parser .parse_args ()
146146
@@ -375,7 +375,7 @@ def main(args):
375375 SEQUENCE_LENGTH = 4096
376376 MBS = 1
377377 GBS = 512
378- TRAIN_STEPS = 400
378+ TRAIN_STEPS = 200
379379 VAL_INTERVAL = 50
380380 # # # # # # # # # # # # # # # # # # # # # #
381381
You can’t perform that action at this time.
0 commit comments