[not4land] Some fixes for MXFP8 #3183
base: main
@@ -665,7 +665,10 @@ def _addmm_mx_dispatch(
     The only difference is whether bias is None or not.
     """
 
+    out_starting_shape = a.shape[:-1]
+
     if not isinstance(a, MXTensor):
+        a = a.reshape(-1, a.shape[-1])
         assert b.act_quant_kwargs is not None, "weight-only quant not yet supported"
         k = b.act_quant_kwargs
         a = MXTensor.to_mx(
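For context, the `out_starting_shape` / `reshape` pair added above is the usual trick of flattening all leading dimensions so that a rank-N activation can be fed to a 2-D matmul kernel and the result expanded back afterwards. A minimal sketch of that pattern in plain PyTorch, with made-up shapes and no MX quantization involved:

```python
import torch

a = torch.randn(2, 8, 16)          # e.g. (batch, seq, in_features)
w = torch.randn(32, 16)            # nn.Linear-style weight: (out_features, in_features)

out_starting_shape = a.shape[:-1]  # (2, 8): everything except the reduction dim
a_2d = a.reshape(-1, a.shape[-1])  # collapse to (16, 16) so a 2-D mm can be used

res_2d = a_2d @ w.t()              # (16, 32)
res = res_2d.reshape(*out_starting_shape, res_2d.shape[-1])  # back to (2, 8, 32)

assert torch.allclose(res, torch.nn.functional.linear(a, w), atol=1e-6)
```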
@@ -698,6 +701,7 @@ def _addmm_mx_dispatch(
                 "CUBLAS is the only supported kernel choice for MX FP8 operations"
             )
 
+            print(f"scaled_mm info: {a.qdata.shape=}, {b.qdata.shape=}, {a_scale_block.shape=}, {b_scale_block.shape=}, a.qdata contig: {a.qdata.is_contiguous()} b.qdata contig: {b.qdata.is_contiguous()} a_scale_block contig: {a_scale_block.is_contiguous()}, b_scale_block contig: {b_scale_block.is_contiguous()}")
             res = torch._scaled_mm(
                 a.qdata,
                 b.qdata,
@@ -706,6 +710,7 @@ def _addmm_mx_dispatch(
                 bias=bias,
                 out_dtype=torch.bfloat16,
             )
+            print(f"after scaled mm {a.qdata.shape=}, {b.qdata.shape=}")
         else:
             assert a._elem_dtype == torch.float4_e2m1fn_x2
             assert b._elem_dtype == torch.float4_e2m1fn_x2
@@ -717,7 +722,6 @@ def _addmm_mx_dispatch(
             # TODO add optional bias to kernel
             if bias is not None:
                 res = res + bias
-
     else:
         # emulated MX gemm
         a_hp = a.dequantize(a._orig_dtype)
@@ -726,12 +730,17 @@ def _addmm_mx_dispatch(
         assert a_hp.is_contiguous()
         assert b_hp.t().is_contiguous()
 
-        # Call appropriate aten_op based on whether bias is provided
-        if bias is not None:
-            res = aten_op(bias, a_hp, b_hp)  # addmm
+        if aten_op == aten.linear.default:
+            res = aten_op(a_hp, b_hp.t(), bias)
         else:
-            res = aten_op(a_hp, b_hp)  # mm
+            # Call appropriate aten_op based on whether bias is provided
+            if bias is not None:
+                res = aten_op(bias, a_hp, b_hp)  # addmm
+            else:
+                res = aten_op(a_hp, b_hp)  # mm
+
+    res = res.reshape(*out_starting_shape, res.shape[-1])
 
     return res
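The emulated branch now has to distinguish three call conventions, since `_addmm_mx_dispatch` can be reached from `aten.mm`, `aten.addmm`, and (with this PR) `aten.linear`; the extra `b_hp.t()` in the linear case undoes the `b.t()` that `mx_linear` below passes into the dispatch. As a reminder of how the argument orders differ, a small self-contained sketch with plain float tensors (shapes made up, no MXTensor involved):

```python
import torch

aten = torch.ops.aten

a = torch.randn(4, 8)    # activations: (M, K)
w = torch.randn(16, 8)   # nn.Linear-style weight: (N, K)
bias = torch.randn(16)

out_mm = aten.mm.default(a, w.t()) + bias       # mm takes (M, K) @ (K, N); bias added manually
out_addmm = aten.addmm.default(bias, a, w.t())  # addmm takes the bias as its first argument
out_linear = aten.linear.default(a, w, bias)    # linear takes the (N, K) weight and transposes it itself

assert torch.allclose(out_mm, out_addmm, atol=1e-5)
assert torch.allclose(out_mm, out_linear, atol=1e-5)
```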
@@ -752,6 +761,14 @@ def mx_addmm(func, types, args, kwargs):
     b = args[2]
     return _addmm_mx_dispatch(a, b, func, bias=bias)
 
+@implements([aten.linear.default])
+def mx_linear(func, types, args, kwargs):
+    assert isinstance(args[0], torch.Tensor) and isinstance(args[1], MXTensor)
+    a = args[0]
+    b = args[1]
+    bias = args[2] if len(args) > 2 else None
+    return _addmm_mx_dispatch(a, b.t(), func, bias=bias)
+
 
 @implements([aten.t.default])
 def mx_t(func, types, args, kwargs):

Review comment (on `def mx_linear(func, types, args, kwargs):`): I think the reshape should live here, the …

Review comment (on `return _addmm_mx_dispatch(a, b.t(), func, bias=bias)`): IMO either don't pass …
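For what it's worth, here is a hedged sketch of the alternative the first reviewer seems to be suggesting (the comment is truncated, so this reading is an assumption): keep the flatten/unflatten local to the `aten.linear` handler so that `_addmm_mx_dispatch` only ever sees 2-D activations.

```python
# Hypothetical alternative, not the code in this PR: do the reshape in mx_linear
# instead of inside _addmm_mx_dispatch.
@implements([aten.linear.default])
def mx_linear(func, types, args, kwargs):
    a, b = args[0], args[1]
    bias = args[2] if len(args) > 2 else None
    out_starting_shape = a.shape[:-1]
    a_2d = a.reshape(-1, a.shape[-1])                        # collapse leading dims
    res = _addmm_mx_dispatch(a_2d, b.t(), func, bias=bias)   # dispatch on 2-D input
    return res.reshape(*out_starting_shape, res.shape[-1])   # restore leading dims
```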
Review comment: is this what makes the test fail, without changes to the product code?