2 files changed, +4 -3 lines changed

@@ -31,6 +31,7 @@ def _infer_quantization_config(quant_config: dict) -> dict | None:
     # First, import required FP8 linear classes from fms-mo
     # Local
     import fms_mo.aiu_addons.fp8.fp8_adapter  # pylint: disable=unused-import
+    import fms_mo.aiu_addons.fp8.fp8_attn  # pylint: disable=unused-import
     import fms_mo.aiu_addons.fp8.fp8_linear  # pylint: disable=unused-import
 
     # This is used by get_linear to decide whether a linear layer
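Like the adapter and linear imports around it, the new fp8_attn import is presumably loaded purely for its side effects: importing the module registers the FP8 attention support at import time, which is why pylint's unused-import warning is suppressed. A minimal sketch of this register-on-import pattern follows; the names here (register_op, _OPS, FP8Attention) are illustrative, not actual fms-mo internals.

# Toy register-on-import demo; names are hypothetical, not fms-mo APIs.
_OPS: dict[str, type] = {}

def register_op(name: str):
    """Class decorator that records an implementation under `name`."""
    def wrapper(cls: type) -> type:
        _OPS[name] = cls
        return cls
    return wrapper

# In fms-mo the decorated class would live in its own module (e.g. fp8_attn);
# importing that module is what runs the decorator and fills the registry.
@register_op("fp8_attention")
class FP8Attention:
    pass

assert "fp8_attention" in _OPS  # populated purely by the definition above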

@@ -321,12 +321,12 @@ def shard_fp8_linear(
     sharding  | param          | shard | dim |
     ----------+----------------+-------+-----|
     colwise   | weight         | Y     | 0   |
-              | weight_scale   | N     | -   |
+              | weight_scale   | Y/N   | 0/- |
               | input_scale    | N     | -   |
               | bias           | Y     | 0   |
     ----------+----------------+-------+-----|
     rowwise   | weight         | Y     | 1   |
-              | weight_scale   | Y/N   | 0/- |
+              | weight_scale   | N     | -   |
               | input_scale    | Y/N   | 0/- |
               | bias           | 0     | -   |
     """
@@ -339,7 +339,7 @@ def shard_fp8_linear(
     ]
     # Scales are per-row or per-tensor
     # Only sharding needed when row parallel and per-row
-    shard_scales = weight_strategy != "tensor" and module_info.sharding_dim == 1
+    shard_scales = weight_strategy != "tensor" and module_info.sharding_dim == 0
     params: dict[str, LinearParameterShardingInfo] = {
         "weight": LinearParameterShardingInfo(
             module_info.sharding_dim, ShardType.SHARD
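The one-line fix matches the corrected table above: per-row weight scales hold one value per output row of the weight, so they must be sharded exactly when the weight's dim 0 is sharded, which is the colwise case (sharding_dim == 0); rowwise sharding splits dim 1 and leaves every rank with all output rows, so the scales stay replicated. A hedged sketch of that rule in plain PyTorch, using torch.chunk rather than the library's LinearParameterShardingInfo machinery:

import torch

# Illustrative shapes only; not fms-mo code.
world_size = 2
weight = torch.randn(8, 4)    # (out_features, in_features)
weight_scale = torch.rand(8)  # per-row: one scale per output row

# colwise (sharding_dim == 0): output rows are split across ranks,
# so the per-row scales must be split the same way.
w_col = torch.chunk(weight, world_size, dim=0)
s_col = torch.chunk(weight_scale, world_size, dim=0)
assert w_col[0].shape[0] == s_col[0].shape[0] == 4

# rowwise (sharding_dim == 1): input columns are split; each rank still
# holds every output row, so weight_scale is replicated, not sharded.
w_row = torch.chunk(weight, world_size, dim=1)
assert w_row[0].shape[0] == weight_scale.shape[0] == 8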