-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[MLIR][XeGPU] make offsets optional for create_nd_tdesc #148335
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 19 commits
b9a6d98
2465050
1077871
42baa22
204d347
2793c81
f23ea03
0bb958b
6793689
4a96c71
689a8a5
02d3795
01718f4
5ef6ca9
26a222d
882313f
456534a
b6f016e
cd518d2
546a3f7
7846955
ded9552
97b6e39
ed1d48e
b3edff6
d3e935b
205fea7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -110,23 +110,34 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface | |
| Variadic<Index>: $offsets, | ||
| Variadic<Index>: $shape, | ||
| Variadic<Index>: $strides, | ||
| DenseI64ArrayAttr: $const_offsets, | ||
| OptionalAttr<DenseI64ArrayAttr>: $const_offsets, | ||
| OptionalAttr<DenseI64ArrayAttr>: $const_shape, | ||
| OptionalAttr<DenseI64ArrayAttr>: $const_strides | ||
| ); | ||
| let results = (outs XeGPU_TensorDesc: $TensorDesc); | ||
|
|
||
| let assemblyFormat = [{ | ||
| $source `` | ||
| custom<DynamicIndexList>($offsets, $const_offsets) | ||
| (`,` custom<DynamicIndexList>($shape, $const_shape)^ | ||
| `,` custom<DynamicIndexList>($strides, $const_strides))? | ||
| custom<OptionalDynamicIndexList>($offsets, $const_offsets) | ||
| (`,` `shape` `:` custom<DynamicIndexList>($shape, $const_shape)^ | ||
| `,` `strides``:` custom<DynamicIndexList>($strides, $const_strides))? | ||
| attr-dict `:` type($source) `->` qualified(type($TensorDesc)) | ||
| }]; | ||
|
|
||
| let results = (outs XeGPU_TensorDesc: $TensorDesc); | ||
|
|
||
| let hasVerifier = 1; | ||
|
|
||
| let builders = [ | ||
| OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType>": $source)>, | ||
|
|
||
| OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType> ": $source, | ||
| "llvm::ArrayRef<OpFoldResult>": $shape, | ||
| "llvm::ArrayRef<OpFoldResult>": $strides)>, | ||
|
|
||
| OpBuilder<(ins "Type": $tdesc, "TypedValue<IntegerType> ": $source, | ||
| "llvm::ArrayRef<OpFoldResult>": $shape, | ||
| "llvm::ArrayRef<OpFoldResult>": $strides)>, | ||
|
|
||
| OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType>": $source, | ||
| "llvm::ArrayRef<OpFoldResult>": $offsets)>, | ||
|
|
||
|
|
@@ -163,9 +174,29 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface | |
| } | ||
|
|
||
| ArrayRef<int64_t> getStaticOffsets(){ | ||
| return getConstOffsets(); | ||
| auto attr = getConstOffsetsAttr(); | ||
|
|
||
| if (attr) | ||
| return attr; | ||
|
|
||
| auto memrefType = llvm::dyn_cast<MemRefType>(getSourceType()); | ||
| int rank = 0; | ||
| if (memrefType) | ||
| //use source memref's rank, as source memref rank may be higher | ||
| rank = memrefType.getRank(); | ||
| else | ||
| //nd_tdesc created from ui64, use nd_tdesc's rank | ||
| rank = getMixedSizes().size(); | ||
|
||
|
|
||
| // The offsets are allowed to be empty. The Traits verification of OffsetSizeAndStrideOpInterface interface assumes offsets being present. | ||
| // It is set to be MAX to indicate user not passed any value, instead of kDynamic which means offsets passed as value. | ||
| setConstOffsets(llvm::SmallVector<int64_t, 4>(rank, std::numeric_limits<int64_t>::max())); | ||
|
|
||
|
||
| attr = getConstOffsetsAttr(); | ||
| return attr; | ||
| } | ||
|
|
||
|
|
||
| /// wrapper for matching with OffsetSizeAndStrideOpInterface | ||
| /// If source is IntegerType or `const_shape` is filled, | ||
| /// it will return `const_shape`, such that mixes of `shape` | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -112,6 +112,64 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy, | |||||
| //===----------------------------------------------------------------------===// | ||||||
| // XeGPU_CreateNdDescOp | ||||||
| //===----------------------------------------------------------------------===// | ||||||
|
|
||||||
| void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, | ||||||
| Type tdesc, TypedValue<MemRefType> source) { | ||||||
| [[maybe_unused]] auto ty = source.getType(); | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ty is only used in the assert statement which is unused in release binary. |
||||||
| assert(ty.hasStaticShape()); | ||||||
|
||||||
|
|
||||||
| build(builder, state, tdesc, source, ValueRange({}) /* dynamic offsets */, | ||||||
| ValueRange({}) /* empty dynamic shape */, | ||||||
| ValueRange({}) /* empty dynamic strides */, | ||||||
| builder.getDenseI64ArrayAttr({}) /* const offsets */, | ||||||
adam-smnk marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||
| builder.getDenseI64ArrayAttr({}) /* empty const shape*/, | ||||||
| builder.getDenseI64ArrayAttr({}) /* empty const strides*/); | ||||||
| } | ||||||
|
|
||||||
| void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, | ||||||
| Type tdesc, TypedValue<MemRefType> source, | ||||||
| llvm::ArrayRef<OpFoldResult> shape, | ||||||
| llvm::ArrayRef<OpFoldResult> strides) { | ||||||
| assert(shape.size() && strides.size() && shape.size() == strides.size()); | ||||||
|
||||||
|
|
||||||
| llvm::SmallVector<int64_t> staticShape; | ||||||
| llvm::SmallVector<int64_t> staticStrides; | ||||||
| llvm::SmallVector<Value> dynamicShape; | ||||||
| llvm::SmallVector<Value> dynamicStrides; | ||||||
|
|
||||||
| dispatchIndexOpFoldResults(shape, dynamicShape, staticShape); | ||||||
| dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); | ||||||
|
|
||||||
| auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape); | ||||||
| auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides); | ||||||
|
|
||||||
| build(builder, state, tdesc, source, ValueRange({}), dynamicShape, | ||||||
| dynamicStrides, builder.getDenseI64ArrayAttr({}), staticShapeAttr, | ||||||
| staticStridesAttr); | ||||||
| } | ||||||
|
|
||||||
| void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, | ||||||
| Type tdesc, TypedValue<IntegerType> source, | ||||||
| llvm::ArrayRef<OpFoldResult> shape, | ||||||
| llvm::ArrayRef<OpFoldResult> strides) { | ||||||
| assert(shape.size() && strides.size() && shape.size() == strides.size()); | ||||||
|
||||||
|
|
||||||
| llvm::SmallVector<int64_t> staticShape; | ||||||
| llvm::SmallVector<int64_t> staticStrides; | ||||||
| llvm::SmallVector<Value> dynamicShape; | ||||||
| llvm::SmallVector<Value> dynamicStrides; | ||||||
|
|
||||||
| dispatchIndexOpFoldResults(shape, dynamicShape, staticShape); | ||||||
| dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); | ||||||
|
|
||||||
| auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape); | ||||||
| auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides); | ||||||
|
|
||||||
| build(builder, state, tdesc, source, ValueRange({}), dynamicShape, | ||||||
| dynamicStrides, builder.getDenseI64ArrayAttr({}), staticShapeAttr, | ||||||
| staticStridesAttr); | ||||||
| } | ||||||
|
|
||||||
| void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, | ||||||
| Type tdesc, TypedValue<MemRefType> source, | ||||||
| llvm::ArrayRef<OpFoldResult> offsets) { | ||||||
|
|
@@ -125,8 +183,8 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, | |||||
| build(builder, state, tdesc, source, dynamicOffsets /* dynamic offsets */, | ||||||
| ValueRange({}) /* empty dynamic shape */, | ||||||
| ValueRange({}) /* empty dynamic strides */, | ||||||
| staticOffsets /* const offsets */, {} /* empty const shape*/, | ||||||
| {} /* empty const strides*/); | ||||||
| builder.getDenseI64ArrayAttr(staticOffsets) /* const offsets */, | ||||||
| {} /* empty const shape*/, {} /* empty const strides*/); | ||||||
| } | ||||||
|
|
||||||
| void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. where is the new build methods implemented?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. added |
||||||
|
|
@@ -221,6 +279,66 @@ LogicalResult CreateNdDescOp::verify() { | |||||
| return success(); | ||||||
| } | ||||||
|
|
||||||
| ParseResult parseOptionalDynamicIndexList( | ||||||
| OpAsmParser &parser, | ||||||
| SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values, | ||||||
| DenseI64ArrayAttr &integers, SmallVectorImpl<Type> *valueTypes = nullptr, | ||||||
| AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) { | ||||||
|
|
||||||
| SmallVector<int64_t, 4> integerVals; | ||||||
| auto parseIntegerOrValue = [&]() { | ||||||
| OpAsmParser::UnresolvedOperand operand; | ||||||
| auto res = parser.parseOptionalOperand(operand); | ||||||
|
|
||||||
| if (res.has_value() && succeeded(res.value())) { | ||||||
| values.push_back(operand); | ||||||
| integerVals.push_back(ShapedType::kDynamic); | ||||||
| if (valueTypes && parser.parseColonType(valueTypes->emplace_back())) | ||||||
| return failure(); | ||||||
| } else { | ||||||
| int64_t integer; | ||||||
| if (failed(parser.parseInteger(integer))) | ||||||
| return failure(); | ||||||
| integerVals.push_back(integer); | ||||||
| } | ||||||
|
|
||||||
| return success(); | ||||||
| }; | ||||||
| if (parser.parseOptionalLSquare().succeeded()) { | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I assume that for no-offset case this check will fail?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. please add a comment here like "If the optional values are given there must be left bracket"
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes. Added |
||||||
| if (parser.parseCommaSeparatedList(parseIntegerOrValue) || | ||||||
| parser.parseRSquare()) | ||||||
| return parser.emitError(parser.getNameLoc()) | ||||||
| << "expected SSA value or integer"; | ||||||
|
||||||
| << "expected SSA value or integer"; | |
| << "expected a list of SSA values or integers"; |
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please add a comment here to explain that here we use some place holder values to carry optional offsets.
nit: I still prefer if this constant is defined somewhere and properly documented. Will make life easier for future changes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is temporary. I added a comment.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe we can remove the valueTypes and related logic, since it is not used.
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
does values and valueTypes always have same size?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems correct from the test cases. This interface is automatically generated by parser which should guarantee it.
Uh oh!
There was an error while loading. Please reload this page.