@@ -141,46 +141,24 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
141141}
142142
143143void CreateNdDescOp::build (OpBuilder &builder, OperationState &state,
144- Type tdesc, TypedValue<MemRefType> source,
144+ Type tdesc, Value source,
145145 llvm::ArrayRef<OpFoldResult> offsets,
146146 llvm::ArrayRef<OpFoldResult> shape,
147147 llvm::ArrayRef<OpFoldResult> strides) {
148148 assert (shape.size () && offsets.size () && strides.size () &&
149149 shape.size () == strides.size () && shape.size () == offsets.size ());
150150
151- llvm::SmallVector<int64_t > staticOffsets;
152- llvm::SmallVector<int64_t > staticShape;
153- llvm::SmallVector<int64_t > staticStrides;
151+ auto intTy = dyn_cast<IntegerType>(source.getType ());
152+ auto memrefTy = dyn_cast<MemRefType>(source.getType ());
153+ assert (intTy || memrefTy && " Source has to be either int or memref." );
154+
154155 llvm::SmallVector<Value> dynamicOffsets;
155156 llvm::SmallVector<Value> dynamicShape;
156157 llvm::SmallVector<Value> dynamicStrides;
157158
158- dispatchIndexOpFoldResults (offsets, dynamicOffsets, staticOffsets);
159- dispatchIndexOpFoldResults (shape, dynamicShape, staticShape);
160- dispatchIndexOpFoldResults (strides, dynamicStrides, staticStrides);
161-
162- auto staticOffsetsAttr = builder.getDenseI64ArrayAttr (staticOffsets);
163- auto staticShapeAttr = builder.getDenseI64ArrayAttr (staticShape);
164- auto staticStridesAttr = builder.getDenseI64ArrayAttr (staticStrides);
165-
166- build (builder, state, tdesc, source, dynamicOffsets, dynamicShape,
167- dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
168- }
169-
170- void CreateNdDescOp::build (OpBuilder &builder, OperationState &state,
171- Type tdesc, TypedValue<IntegerType> source,
172- llvm::ArrayRef<OpFoldResult> offsets,
173- llvm::ArrayRef<OpFoldResult> shape,
174- llvm::ArrayRef<OpFoldResult> strides) {
175- assert (shape.size () && offsets.size () && strides.size () &&
176- shape.size () == strides.size () && shape.size () == offsets.size ());
177-
178159 llvm::SmallVector<int64_t > staticOffsets;
179160 llvm::SmallVector<int64_t > staticShape;
180161 llvm::SmallVector<int64_t > staticStrides;
181- llvm::SmallVector<Value> dynamicOffsets;
182- llvm::SmallVector<Value> dynamicShape;
183- llvm::SmallVector<Value> dynamicStrides;
184162
185163 dispatchIndexOpFoldResults (offsets, dynamicOffsets, staticOffsets);
186164 dispatchIndexOpFoldResults (shape, dynamicShape, staticShape);
@@ -190,6 +168,17 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
190168 auto staticShapeAttr = builder.getDenseI64ArrayAttr (staticShape);
191169 auto staticStridesAttr = builder.getDenseI64ArrayAttr (staticStrides);
192170
171+ if (memrefTy) {
172+ auto memrefShape = memrefTy.getShape ();
173+ auto [memrefStrides, offset] = memrefTy.getStridesAndOffset ();
174+
175+ // if shape and strides are from Memref, we don't need attributes for them
176+ if (staticShape == memrefShape && staticStrides == memrefStrides) {
177+ staticShapeAttr = DenseI64ArrayAttr ();
178+ staticStridesAttr = DenseI64ArrayAttr ();
179+ }
180+ }
181+
193182 build (builder, state, tdesc, source, dynamicOffsets, dynamicShape,
194183 dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
195184}
0 commit comments