@@ -137,6 +137,33 @@ struct VectorBroadcastConvert final
137137 }
138138};
139139
140+ // SPIR-V does not have a concept of a poison index for certain instructions,
141+ // which creates a UB hazard when lowering from otherwise equivalent Vector
142+ // dialect instructions, because this index will be considered out-of-bounds.
143+ // To avoid this, this function implements a dynamic sanitization that returns
144+ // some arbitrary safe index. For power-of-two vector sizes, this uses a bitmask
145+ // (presumably more efficient), and otherwise index 0 (always in-bounds).
146+ static Value sanitizeDynamicIndex (ConversionPatternRewriter &rewriter,
147+ Location loc, Value dynamicIndex,
148+ int64_t kPoisonIndex , unsigned vectorSize) {
149+ if (llvm::isPowerOf2_32 (vectorSize)) {
150+ Value inBoundsMask = rewriter.create <spirv::ConstantOp>(
151+ loc, dynamicIndex.getType (),
152+ rewriter.getIntegerAttr (dynamicIndex.getType (), vectorSize - 1 ));
153+ return rewriter.create <spirv::BitwiseAndOp>(loc, dynamicIndex,
154+ inBoundsMask);
155+ }
156+ Value poisonIndex = rewriter.create <spirv::ConstantOp>(
157+ loc, dynamicIndex.getType (),
158+ rewriter.getIntegerAttr (dynamicIndex.getType (), kPoisonIndex ));
159+ Value cmpResult =
160+ rewriter.create <spirv::IEqualOp>(loc, dynamicIndex, poisonIndex);
161+ return rewriter.create <spirv::SelectOp>(
162+ loc, cmpResult,
163+ spirv::ConstantOp::getZero (dynamicIndex.getType (), loc, rewriter),
164+ dynamicIndex);
165+ }
166+
140167struct VectorExtractOpConvert final
141168 : public OpConversionPattern<vector::ExtractOp> {
142169 using OpConversionPattern::OpConversionPattern;
@@ -154,14 +181,26 @@ struct VectorExtractOpConvert final
154181 }
155182
156183 if (std::optional<int64_t > id =
157- getConstantIntValue (extractOp.getMixedPosition ()[0 ]))
158- rewriter.replaceOpWithNewOp <spirv::CompositeExtractOp>(
159- extractOp, dstType, adaptor.getVector (),
160- rewriter.getI32ArrayAttr (id.value ()));
161- else
184+ getConstantIntValue (extractOp.getMixedPosition ()[0 ])) {
185+ // TODO: ExtractOp::fold() already can fold a static poison index to
186+ // ub.poison; remove this once ub.poison can be converted to SPIR-V.
187+ if (id == vector::ExtractOp::kPoisonIndex ) {
188+ // Arbitrary choice of poison result, intended to stick out.
189+ Value zero =
190+ spirv::ConstantOp::getZero (dstType, extractOp.getLoc (), rewriter);
191+ rewriter.replaceOp (extractOp, zero);
192+ } else
193+ rewriter.replaceOpWithNewOp <spirv::CompositeExtractOp>(
194+ extractOp, dstType, adaptor.getVector (),
195+ rewriter.getI32ArrayAttr (id.value ()));
196+ } else {
197+ Value sanitizedIndex = sanitizeDynamicIndex (
198+ rewriter, extractOp.getLoc (), adaptor.getDynamicPosition ()[0 ],
199+ vector::ExtractOp::kPoisonIndex ,
200+ extractOp.getSourceVectorType ().getNumElements ());
162201 rewriter.replaceOpWithNewOp <spirv::VectorExtractDynamicOp>(
163- extractOp, dstType, adaptor.getVector (),
164- adaptor. getDynamicPosition ()[ 0 ]);
202+ extractOp, dstType, adaptor.getVector (), sanitizedIndex);
203+ }
165204 return success ();
166205 }
167206};
@@ -266,13 +305,25 @@ struct VectorInsertOpConvert final
266305 }
267306
268307 if (std::optional<int64_t > id =
269- getConstantIntValue (insertOp.getMixedPosition ()[0 ]))
270- rewriter.replaceOpWithNewOp <spirv::CompositeInsertOp>(
271- insertOp, adaptor.getSource (), adaptor.getDest (), id.value ());
272- else
308+ getConstantIntValue (insertOp.getMixedPosition ()[0 ])) {
309+ // TODO: ExtractOp::fold() already can fold a static poison index to
310+ // ub.poison; remove this once ub.poison can be converted to SPIR-V.
311+ if (id == vector::InsertOp::kPoisonIndex ) {
312+ // Arbitrary choice of poison result, intended to stick out.
313+ Value zero = spirv::ConstantOp::getZero (insertOp.getDestVectorType (),
314+ insertOp.getLoc (), rewriter);
315+ rewriter.replaceOp (insertOp, zero);
316+ } else
317+ rewriter.replaceOpWithNewOp <spirv::CompositeInsertOp>(
318+ insertOp, adaptor.getSource (), adaptor.getDest (), id.value ());
319+ } else {
320+ Value sanitizedIndex = sanitizeDynamicIndex (
321+ rewriter, insertOp.getLoc (), adaptor.getDynamicPosition ()[0 ],
322+ vector::InsertOp::kPoisonIndex ,
323+ insertOp.getDestVectorType ().getNumElements ());
273324 rewriter.replaceOpWithNewOp <spirv::VectorInsertDynamicOp>(
274- insertOp, insertOp.getDest (), adaptor.getSource (),
275- adaptor. getDynamicPosition ()[ 0 ]);
325+ insertOp, insertOp.getDest (), adaptor.getSource (), sanitizedIndex);
326+ }
276327 return success ();
277328 }
278329};
0 commit comments