@@ -244,7 +244,21 @@ def __repr__(self) -> str:
         return (
             f"{self.__class__.__name__}"
             f"(in={self.in_features}, out={self.out_features}, "
-            f"bias={self.has_bias}, fp8_config={self.linear_config})"
+            f"bias={self.has_bias}, fp8_config={self._repr_fp8_config()})"
+        )
+
+    def _repr_fp8_config(self) -> str:
+        return (
+            "("
+            "acts: ("
+            f"dynamic: {self.linear_config['input_activations']['dynamic']}, "
+            f"strategy: {self.linear_config['input_activations']['strategy']}"
+            "), "
+            "weights: ("
+            f"dynamic: {self.linear_config['weights']['dynamic']}, "
+            f"strategy: {self.linear_config['weights']['strategy']}"
+            ")"
+            ")"
+        )
         )

     def get_fp8_linear(
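For a quick sanity check of the new repr format, here is a minimal standalone sketch. Only the dict keys (`input_activations`, `weights`, `dynamic`, `strategy`) are taken from the diff above; the config values and the free-function wrapper are hypothetical:

    # Standalone sketch of the _repr_fp8_config string layout; the config
    # values below are made up, only the dict shape comes from the diff.
    linear_config = {
        "input_activations": {"dynamic": True, "strategy": "tensor"},
        "weights": {"dynamic": False, "strategy": "tensor"},
    }

    def repr_fp8_config(cfg: dict) -> str:
        return (
            "("
            f"acts: (dynamic: {cfg['input_activations']['dynamic']}, "
            f"strategy: {cfg['input_activations']['strategy']}), "
            f"weights: (dynamic: {cfg['weights']['dynamic']}, "
            f"strategy: {cfg['weights']['strategy']})"
            ")"
        )

    print(repr_fp8_config(linear_config))
    # (acts: (dynamic: True, strategy: tensor), weights: (dynamic: False, strategy: tensor))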
@@ -266,14 +280,14 @@ def shard_fp8_linear(
     sharding  | param          | shard | dim |
     ----------+----------------+-------+-----|
     colwise   | weight         | Y     | 0   |
-              | weight_scale   | N     | -   |
-              | input_scale    | N     | -   |
-              | bias           | Y     | 0   |
+              | weight_scale   | N     | -   |
+              | input_scale    | N     | -   |
+              | bias           | Y     | 0   |
     ----------+----------------+-------+-----|
     rowwise   | weight         | Y     | 1   |
-              | weight_scale   | Y/N   | 0/- |
-              | input_scale    | Y/N   | 0/- |
-              | bias           | 0     | -   |
+              | weight_scale   | Y/N   | 0/- |
+              | input_scale    | Y/N   | 0/- |
+              | bias           | 0     | -   |
     """

     param_sharding_info: dict[str, dict[str, LinearParameterShardingInfo]] = {}
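To make the docstring table concrete, here is a hedged sketch of the colwise/rowwise rules with plain torch tensors. The shapes, world size, and per-tensor (scalar) scales are illustrative assumptions, and the rowwise "bias | 0 | -" row is read here as "bias stays whole on one rank and is added once after the all-reduce":

    # Illustrative only: shapes and world_size are made up, and scales are
    # assumed per-tensor (scalar), which is why they replicate under colwise.
    import torch

    out_features, in_features, world_size, rank = 8, 4, 2, 0

    weight = torch.randn(out_features, in_features)
    bias = torch.randn(out_features)
    weight_scale = torch.tensor(0.1)  # replicated in both schemes below
    input_scale = torch.tensor(0.2)

    # colwise: weight and bias shard along dim 0; per-tensor scales replicate.
    w_col = weight.chunk(world_size, dim=0)[rank]  # shape (4, 4)
    b_col = bias.chunk(world_size, dim=0)[rank]    # shape (4,)

    # rowwise: weight shards along dim 1. The table's Y/N for the scales
    # distinguishes per-channel scales (shard on dim 0) from per-tensor
    # scales (replicate); bias is kept on a single rank and added once
    # after the all-reduce (our reading of the "bias | 0 | -" row).
    w_row = weight.chunk(world_size, dim=1)[rank]  # shape (8, 2)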