1717#include " mlir/Dialect/SPIRV/IR/SPIRVOps.h"
1818#include " mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
1919#include " mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
20+ #include " mlir/Dialect/Utils/StaticValueUtils.h"
2021#include " mlir/Dialect/Vector/IR/VectorOps.h"
2122#include " mlir/IR/Attributes.h"
2223#include " mlir/IR/BuiltinAttributes.h"
@@ -40,22 +41,9 @@ using namespace mlir;
4041// / Returns the integer value from the first valid input element, assuming Value
4142// / inputs are defined by a constant index ops and Attribute inputs are integer
4243// / attributes.
43- static uint64_t getFirstIntValue (ValueRange values) {
44- return values[0 ].getDefiningOp <arith::ConstantIndexOp>().value ();
45- }
46- static uint64_t getFirstIntValue (ArrayRef<Attribute> attr) {
47- return cast<IntegerAttr>(attr[0 ]).getInt ();
48- }
4944static uint64_t getFirstIntValue (ArrayAttr attr) {
5045 return (*attr.getAsValueRange <IntegerAttr>().begin ()).getZExtValue ();
5146}
52- static uint64_t getFirstIntValue (ArrayRef<OpFoldResult> foldResults) {
53- auto attr = foldResults[0 ].dyn_cast <Attribute>();
54- if (attr)
55- return getFirstIntValue (attr);
56-
57- return getFirstIntValue (ValueRange{foldResults[0 ].get <Value>()});
58- }
5947
6048// / Returns the number of bits for the given scalar/vector type.
6149static int getNumBits (Type type) {
@@ -157,9 +145,6 @@ struct VectorExtractOpConvert final
157145 LogicalResult
158146 matchAndRewrite (vector::ExtractOp extractOp, OpAdaptor adaptor,
159147 ConversionPatternRewriter &rewriter) const override {
160- if (extractOp.hasDynamicPosition ())
161- return failure ();
162-
163148 Type dstType = getTypeConverter ()->convertType (extractOp.getType ());
164149 if (!dstType)
165150 return failure ();
@@ -169,9 +154,15 @@ struct VectorExtractOpConvert final
169154 return success ();
170155 }
171156
172- int32_t id = getFirstIntValue (extractOp.getMixedPosition ());
173- rewriter.replaceOpWithNewOp <spirv::CompositeExtractOp>(
174- extractOp, adaptor.getVector (), id);
157+ if (std::optional<int64_t > id =
158+ getConstantIntValue (extractOp.getMixedPosition ()[0 ]))
159+ rewriter.replaceOpWithNewOp <spirv::CompositeExtractOp>(
160+ extractOp, dstType, adaptor.getVector (),
161+ rewriter.getI32ArrayAttr (id.value ()));
162+ else
163+ rewriter.replaceOpWithNewOp <spirv::VectorExtractDynamicOp>(
164+ extractOp, dstType, adaptor.getVector (),
165+ adaptor.getDynamicPosition ()[0 ]);
175166 return success ();
176167 }
177168};
@@ -249,9 +240,14 @@ struct VectorInsertOpConvert final
249240 return success ();
250241 }
251242
252- int32_t id = getFirstIntValue (insertOp.getMixedPosition ());
253- rewriter.replaceOpWithNewOp <spirv::CompositeInsertOp>(
254- insertOp, adaptor.getSource (), adaptor.getDest (), id);
243+ if (std::optional<int64_t > id =
244+ getConstantIntValue (insertOp.getMixedPosition ()[0 ]))
245+ rewriter.replaceOpWithNewOp <spirv::CompositeInsertOp>(
246+ insertOp, adaptor.getSource (), adaptor.getDest (), id.value ());
247+ else
248+ rewriter.replaceOpWithNewOp <spirv::VectorInsertDynamicOp>(
249+ insertOp, insertOp.getDest (), adaptor.getSource (),
250+ adaptor.getDynamicPosition ()[0 ]);
255251 return success ();
256252 }
257253};
0 commit comments