@@ -156,48 +156,37 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
156156}
157157
158158void CreateNdDescOp::build (OpBuilder &builder, OperationState &state,
159- Type tdesc, TypedValue<MemRefType> source,
159+ Type tdesc, Value source,
160160 llvm::ArrayRef<OpFoldResult> shape,
161161 llvm::ArrayRef<OpFoldResult> strides) {
162- assert (shape. size () && strides. size () && shape. size () == strides. size () &&
163- " Shape and strides must be present and of equal size for ui64 "
164- " initialization ." );
162+ Type srcTy = source. getType ();
163+ assert ((isa<IntegerType, MemRefType>(srcTy)) &&
164+ " Source has to be either int or memref ." );
165165
166- llvm::SmallVector<int64_t > staticShape;
167- llvm::SmallVector<int64_t > staticStrides;
168166 llvm::SmallVector<Value> dynamicShape;
169167 llvm::SmallVector<Value> dynamicStrides;
170168
171- dispatchIndexOpFoldResults (shape, dynamicShape, staticShape);
172- dispatchIndexOpFoldResults (strides, dynamicStrides, staticStrides);
173-
174- auto staticShapeAttr = builder.getDenseI64ArrayAttr (staticShape);
175- auto staticStridesAttr = builder.getDenseI64ArrayAttr (staticStrides);
176-
177- build (builder, state, tdesc, source, ValueRange ({}), dynamicShape,
178- dynamicStrides, builder.getDenseI64ArrayAttr ({}), staticShapeAttr,
179- staticStridesAttr);
180- }
181-
182- void CreateNdDescOp::build (OpBuilder &builder, OperationState &state,
183- Type tdesc, TypedValue<IntegerType> source,
184- llvm::ArrayRef<OpFoldResult> shape,
185- llvm::ArrayRef<OpFoldResult> strides) {
186- assert (shape.size () && strides.size () && shape.size () == strides.size () &&
187- " Shape and strides must be present and of equal size for ui64 "
188- " initialization." );
189-
190169 llvm::SmallVector<int64_t > staticShape;
191170 llvm::SmallVector<int64_t > staticStrides;
192- llvm::SmallVector<Value> dynamicShape;
193- llvm::SmallVector<Value> dynamicStrides;
194171
195172 dispatchIndexOpFoldResults (shape, dynamicShape, staticShape);
196173 dispatchIndexOpFoldResults (strides, dynamicStrides, staticStrides);
197174
198175 auto staticShapeAttr = builder.getDenseI64ArrayAttr (staticShape);
199176 auto staticStridesAttr = builder.getDenseI64ArrayAttr (staticStrides);
200177
178+ if (auto memrefTy = dyn_cast<MemRefType>(srcTy)) {
179+ auto memrefShape = memrefTy.getShape ();
180+ auto [memrefStrides, _] = memrefTy.getStridesAndOffset ();
181+
182+ // if shape and strides are from Memref, we don't need attributes for them
183+ // to keep the IR print clean.
184+ if (staticShape == memrefShape && staticStrides == memrefStrides) {
185+ staticShapeAttr = DenseI64ArrayAttr ();
186+ staticStridesAttr = DenseI64ArrayAttr ();
187+ }
188+ }
189+
201190 build (builder, state, tdesc, source, ValueRange ({}), dynamicShape,
202191 dynamicStrides, builder.getDenseI64ArrayAttr ({}), staticShapeAttr,
203192 staticStridesAttr);
@@ -357,13 +346,10 @@ ParseResult parseOptionalDynamicIndexList(
357346void printOptionalDynamicIndexList (OpAsmPrinter &printer, Operation *op,
358347 OperandRange values,
359348 DenseI64ArrayAttr integers) {
360-
361- if (!integers)
349+ if (!integers || integers.empty ())
362350 return ;
363-
364- return printDynamicIndexList (printer, op, values, integers,
365- /* scalableFlags=*/ {}, {},
366- AsmParser::Delimiter::Square);
351+ printDynamicIndexList (printer, op, values, integers,
352+ /* scalableFlags=*/ {}, {}, AsmParser::Delimiter::Square);
367353}
368354// ===----------------------------------------------------------------------===//
369355// XeGPU_PrefetchNdOp
0 commit comments