
Conversation

@etiotto
Contributor

@etiotto etiotto commented Nov 22, 2024

This PR decomposes a tt.dot_scaled operation into a tt.dot operation where the scaled operand (e.g., A) is upcast using the triton_gpu.upcast_mxfp operation.

Note: The upcast_mxfp operation is not lowered to LLVM IR in this PR.
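
For readers new to this kind of transform, a minimal sketch of the decomposition as an MLIR rewrite pattern follows. Every op class and accessor name in it (DotScaledOp, UpcastMXFPOp, DotOp, getAScale, getAElemType, ...) is an assumption made for illustration; this is not the PR's actual code.

// Hedged sketch only: op classes and their accessors are hypothetical.
#include "mlir/IR/PatternMatch.h"

namespace {
struct DecomposeScaledDot : mlir::OpRewritePattern<DotScaledOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(DotScaledOp op,
                  mlir::PatternRewriter &rewriter) const override {
    // Apply the per-block scales to the scaled operand (A here) by
    // upcasting it to a plain floating-point tensor...
    mlir::Value a = rewriter.create<UpcastMXFPOp>(
        op.getLoc(), op.getA(), op.getAScale(), op.getAElemType());
    // ...then replace the scaled dot with an ordinary tt.dot.
    rewriter.replaceOpWithNewOp<DotOp>(op, op.getD().getType(), a,
                                       op.getB(), op.getC());
    return mlir::success();
  }
};
} // anonymous namespace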

@etiotto etiotto self-assigned this Nov 22, 2024
@etiotto etiotto linked an issue Nov 22, 2024 that may be closed by this pull request
@etiotto etiotto marked this pull request as ready for review November 25, 2024 14:25
@anmyachev
Contributor

@etiotto in order to test these changes, we need to unskip test_scaled_dot:

if is_xpu():
    pytest.skip("scaled_dot isn't supported on XPU")

Did you plan to do this here?

@etiotto
Contributor Author

etiotto commented Nov 26, 2024

@etiotto in order to test these changes, we need to unskip test_scaled_dot:

if is_xpu():
    pytest.skip("scaled_dot isn't supported on XPU")

Did you plan to do this here?

Yes, we need to do that; however, the code is not yet fully functional because the lowering for the triton_gpu.upcast_mxfp operation is not working yet. I will remove that part of the PR and have this PR deal only with decomposing the dot_scaled operation.

@etiotto etiotto marked this pull request as draft November 26, 2024 00:38
Signed-off-by: Tiotto, Ettore <[email protected]>
Signed-off-by: Tiotto, Ettore <[email protected]>
Signed-off-by: Tiotto, Ettore <[email protected]>
@etiotto etiotto marked this pull request as ready for review November 26, 2024 21:15
@victor-eds
Contributor

PR approach LGTM. Just some NITs.

@whitneywhtsang
Contributor

PR approach LGTM. Just some NITs.

@victor-eds Maybe you forgot to submit the NITs?

Contributor

@leonling-ll leonling-ll left a comment


LGTM.

newShape[kIdx] *= 2;
retTy = RankedTensorType::get(newShape, FloatType::getBF16(ctx),
                              newVEncoding);
Type elemType = FloatType::getBF16(ctx);
Contributor


Can we define this inside the if statement below?

Contributor Author


I'd rather have it here because elemType is used after the if/else at line 147.
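
To illustrate the scoping argument with a generic C++ sketch (names such as useNewEncoding are made up; this is not the code under review): a value consumed after an if/else must be declared before it.

// Generic illustration of why elemType is hoisted above the if/else.
Type elemType = FloatType::getBF16(ctx); // needed past the if/else
RankedTensorType retTy;
if (useNewEncoding) {
  retTy = RankedTensorType::get(newShape, elemType, newVEncoding);
} else {
  retTy = RankedTensorType::get(newShape, elemType, oldEncoding);
}
// elemType is referenced again below this point, so it cannot be
// scoped inside either branch.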

Comment on lines +6 to +8
namespace mlir {
class ModuleOp;
}
Contributor


I'd bet we don't need this

Contributor Author


ModuleOp is used in "intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.h.inc" now:

static DPASCapability getDPASCapability(mlir::ModuleOp mod);

That is why I put the forward declaration here.
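
As a general C++ note (not specific to this repository): an incomplete type is sufficient for a function declaration, even when the parameter is passed by value; the complete definition is only required where the function is defined or called. A hedged sketch:

// Standard C++: a forward-declared (incomplete) type may appear in a
// function *declaration*, even by value.
namespace mlir {
class ModuleOp; // incomplete type here
}

struct DPASCapability {};

// Legal with only the forward declaration above; the full definition
// of mlir::ModuleOp is needed only at the definition and call sites.
DPASCapability getDPASCapability(mlir::ModuleOp mod);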

Signed-off-by: Tiotto, Ettore <[email protected]>
@etiotto etiotto enabled auto-merge (squash) November 29, 2024 17:18
@etiotto etiotto merged commit 0c70ca3 into main Nov 29, 2024
5 checks passed
@etiotto etiotto deleted the etiotto.add_support_for_scaled_dot branch November 29, 2024 17:48


Development

Successfully merging this pull request may close these issues.

Implement support for the tt.dot_scaled operation on XPU
