1717#include " mlir/Dialect/SPIRV/IR/SPIRVOps.h"
1818#include " mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
1919#include " mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
20+ #include " mlir/Dialect/Utils/StaticValueUtils.h"
2021#include " mlir/Dialect/Vector/IR/VectorOps.h"
2122#include " mlir/IR/Attributes.h"
2223#include " mlir/IR/BuiltinAttributes.h"
@@ -40,22 +41,9 @@ using namespace mlir;
4041// / Returns the integer value from the first valid input element, assuming Value
4142// / inputs are defined by a constant index ops and Attribute inputs are integer
4243// / attributes.
43- static uint64_t getFirstIntValue (ValueRange values) {
44- return values[0 ].getDefiningOp <arith::ConstantIndexOp>().value ();
45- }
46- static uint64_t getFirstIntValue (ArrayRef<Attribute> attr) {
47- return cast<IntegerAttr>(attr[0 ]).getInt ();
48- }
4944static uint64_t getFirstIntValue (ArrayAttr attr) {
5045 return (*attr.getAsValueRange <IntegerAttr>().begin ()).getZExtValue ();
5146}
52- static uint64_t getFirstIntValue (ArrayRef<OpFoldResult> foldResults) {
53- auto attr = foldResults[0 ].dyn_cast <Attribute>();
54- if (attr)
55- return getFirstIntValue (attr);
56-
57- return getFirstIntValue (ValueRange{foldResults[0 ].get <Value>()});
58- }
5947
6048// / Returns the number of bits for the given scalar/vector type.
6149static int getNumBits (Type type) {
@@ -157,9 +145,6 @@ struct VectorExtractOpConvert final
157145 LogicalResult
158146 matchAndRewrite (vector::ExtractOp extractOp, OpAdaptor adaptor,
159147 ConversionPatternRewriter &rewriter) const override {
160- if (extractOp.hasDynamicPosition ())
161- return failure ();
162-
163148 Type dstType = getTypeConverter ()->convertType (extractOp.getType ());
164149 if (!dstType)
165150 return failure ();
@@ -169,9 +154,17 @@ struct VectorExtractOpConvert final
169154 return success ();
170155 }
171156
172- int32_t id = getFirstIntValue (extractOp.getMixedPosition ());
173- rewriter.replaceOpWithNewOp <spirv::CompositeExtractOp>(
174- extractOp, adaptor.getVector (), id);
157+ std::optional<int64_t > id =
158+ getConstantIntValue (extractOp.getMixedPosition ()[0 ]);
159+
160+ if (id.has_value ())
161+ rewriter.replaceOpWithNewOp <spirv::CompositeExtractOp>(
162+ extractOp, dstType, adaptor.getVector (),
163+ rewriter.getI32ArrayAttr (id.value ()));
164+ else
165+ rewriter.replaceOpWithNewOp <spirv::VectorExtractDynamicOp>(
166+ extractOp, dstType, adaptor.getVector (),
167+ adaptor.getDynamicPosition ()[0 ]);
175168 return success ();
176169 }
177170};
@@ -249,9 +242,20 @@ struct VectorInsertOpConvert final
249242 return success ();
250243 }
251244
252- int32_t id = getFirstIntValue (insertOp.getMixedPosition ());
253- rewriter.replaceOpWithNewOp <spirv::CompositeInsertOp>(
254- insertOp, adaptor.getSource (), adaptor.getDest (), id);
245+ std::optional<int64_t > id =
246+ getConstantIntValue (insertOp.getMixedPosition ()[0 ]);
247+
248+ // rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
249+ // insertOp, adaptor.getSource(), adaptor.getDest(), id);
250+ // return success();
251+
252+ if (id.has_value ())
253+ rewriter.replaceOpWithNewOp <spirv::CompositeInsertOp>(
254+ insertOp, adaptor.getSource (), adaptor.getDest (), id.value ());
255+ else
256+ rewriter.replaceOpWithNewOp <spirv::VectorInsertDynamicOp>(
257+ insertOp, insertOp.getDest (), adaptor.getSource (),
258+ adaptor.getDynamicPosition ()[0 ]);
255259 return success ();
256260 }
257261};
0 commit comments