@@ -141,46 +141,24 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
141141}
142142
143143void CreateNdDescOp::build (OpBuilder &builder, OperationState &state,
144- Type tdesc, TypedValue<MemRefType> source,
144+ Type tdesc, Value source,
145145 llvm::ArrayRef<OpFoldResult> offsets,
146146 llvm::ArrayRef<OpFoldResult> shape,
147147 llvm::ArrayRef<OpFoldResult> strides) {
148148 assert (shape.size () && offsets.size () && strides.size () &&
149149 shape.size () == strides.size () && shape.size () == offsets.size ());
150150
151- llvm::SmallVector<int64_t > staticOffsets;
152- llvm::SmallVector<int64_t > staticShape;
153- llvm::SmallVector<int64_t > staticStrides;
151+ Type srcTy = source.getType ();
152+ assert (isa<IntegerType>(srcTy) ||
153+ isa<MemRefType>(srcTy) && " Source has to be either int or memref." );
154+
154155 llvm::SmallVector<Value> dynamicOffsets;
155156 llvm::SmallVector<Value> dynamicShape;
156157 llvm::SmallVector<Value> dynamicStrides;
157158
158- dispatchIndexOpFoldResults (offsets, dynamicOffsets, staticOffsets);
159- dispatchIndexOpFoldResults (shape, dynamicShape, staticShape);
160- dispatchIndexOpFoldResults (strides, dynamicStrides, staticStrides);
161-
162- auto staticOffsetsAttr = builder.getDenseI64ArrayAttr (staticOffsets);
163- auto staticShapeAttr = builder.getDenseI64ArrayAttr (staticShape);
164- auto staticStridesAttr = builder.getDenseI64ArrayAttr (staticStrides);
165-
166- build (builder, state, tdesc, source, dynamicOffsets, dynamicShape,
167- dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
168- }
169-
170- void CreateNdDescOp::build (OpBuilder &builder, OperationState &state,
171- Type tdesc, TypedValue<IntegerType> source,
172- llvm::ArrayRef<OpFoldResult> offsets,
173- llvm::ArrayRef<OpFoldResult> shape,
174- llvm::ArrayRef<OpFoldResult> strides) {
175- assert (shape.size () && offsets.size () && strides.size () &&
176- shape.size () == strides.size () && shape.size () == offsets.size ());
177-
178159 llvm::SmallVector<int64_t > staticOffsets;
179160 llvm::SmallVector<int64_t > staticShape;
180161 llvm::SmallVector<int64_t > staticStrides;
181- llvm::SmallVector<Value> dynamicOffsets;
182- llvm::SmallVector<Value> dynamicShape;
183- llvm::SmallVector<Value> dynamicStrides;
184162
185163 dispatchIndexOpFoldResults (offsets, dynamicOffsets, staticOffsets);
186164 dispatchIndexOpFoldResults (shape, dynamicShape, staticShape);
@@ -190,6 +168,18 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
190168 auto staticShapeAttr = builder.getDenseI64ArrayAttr (staticShape);
191169 auto staticStridesAttr = builder.getDenseI64ArrayAttr (staticStrides);
192170
171+ if (auto memrefTy = dyn_cast<MemRefType>(srcTy)) {
172+ auto memrefShape = memrefTy.getShape ();
173+ auto [memrefStrides, _] = memrefTy.getStridesAndOffset ();
174+
175+ // if shape and strides are from Memref, we don't need attributes for them
176+ // to keep the IR print clean.
177+ if (staticShape == memrefShape && staticStrides == memrefStrides) {
178+ staticShapeAttr = DenseI64ArrayAttr ();
179+ staticStridesAttr = DenseI64ArrayAttr ();
180+ }
181+ }
182+
193183 build (builder, state, tdesc, source, dynamicOffsets, dynamicShape,
194184 dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
195185}
0 commit comments