@@ -137,6 +137,26 @@ struct VectorBroadcastConvert final
137137 }
138138};
139139
140+ // SPIR-V does not have a concept of a poison index for certain instructions,
141+ // which creates a UB hazard when lowering from otherwise equivalent Vector
142+ // dialect instructions, because this index will be considered out-of-bounds.
143+ // To avoid this, this function implements a dynamic sanitization, arbitrarily
144+ // choosing to replace the poison index with index 0 (always in-bounds).
145+ static Value sanitizeDynamicIndex (ConversionPatternRewriter &rewriter,
146+ Location loc, Value dynamicIndex,
147+ int64_t kPoisonIndex ) {
148+ Value poisonIndex = rewriter.create <spirv::ConstantOp>(
149+ loc, dynamicIndex.getType (),
150+ rewriter.getIntegerAttr (dynamicIndex.getType (), kPoisonIndex ));
151+ Value cmpResult =
152+ rewriter.create <spirv::IEqualOp>(loc, dynamicIndex, poisonIndex);
153+ Value sanitizedIndex = rewriter.create <spirv::SelectOp>(
154+ loc, cmpResult,
155+ spirv::ConstantOp::getZero (dynamicIndex.getType (), loc, rewriter),
156+ dynamicIndex);
157+ return sanitizedIndex;
158+ }
159+
140160struct VectorExtractOpConvert final
141161 : public OpConversionPattern<vector::ExtractOp> {
142162 using OpConversionPattern::OpConversionPattern;
@@ -154,14 +174,26 @@ struct VectorExtractOpConvert final
154174 }
155175
156176 if (std::optional<int64_t > id =
157- getConstantIntValue (extractOp.getMixedPosition ()[0 ]))
158- rewriter.replaceOpWithNewOp <spirv::CompositeExtractOp>(
159- extractOp, dstType, adaptor.getVector (),
160- rewriter.getI32ArrayAttr (id.value ()));
161- else
177+ getConstantIntValue (extractOp.getMixedPosition ()[0 ])) {
178+ // TODO: It would be better to apply the ub.poison folding for this case
179+ // unconditionally, and have a specific SPIR-V lowering for it,
180+ // rather than having to handle it here.
181+ if (id == vector::ExtractOp::kPoisonIndex ) {
182+ // Arbitrary choice of poison result, intended to stick out.
183+ Value zero =
184+ spirv::ConstantOp::getZero (dstType, extractOp.getLoc (), rewriter);
185+ rewriter.replaceOp (extractOp, zero);
186+ } else
187+ rewriter.replaceOpWithNewOp <spirv::CompositeExtractOp>(
188+ extractOp, dstType, adaptor.getVector (),
189+ rewriter.getI32ArrayAttr (id.value ()));
190+ } else {
191+ Value sanitizedIndex = sanitizeDynamicIndex (
192+ rewriter, extractOp.getLoc (), adaptor.getDynamicPosition ()[0 ],
193+ vector::ExtractOp::kPoisonIndex );
162194 rewriter.replaceOpWithNewOp <spirv::VectorExtractDynamicOp>(
163- extractOp, dstType, adaptor.getVector (),
164- adaptor. getDynamicPosition ()[ 0 ]);
195+ extractOp, dstType, adaptor.getVector (), sanitizedIndex);
196+ }
165197 return success ();
166198 }
167199};
@@ -266,13 +298,25 @@ struct VectorInsertOpConvert final
266298 }
267299
268300 if (std::optional<int64_t > id =
269- getConstantIntValue (insertOp.getMixedPosition ()[0 ]))
270- rewriter.replaceOpWithNewOp <spirv::CompositeInsertOp>(
271- insertOp, adaptor.getSource (), adaptor.getDest (), id.value ());
272- else
301+ getConstantIntValue (insertOp.getMixedPosition ()[0 ])) {
302+ // TODO: It would be better to apply the ub.poison folding for this case
303+ // unconditionally, and have a specific SPIR-V lowering for it,
304+ // rather than having to handle it here.
305+ if (id == vector::InsertOp::kPoisonIndex ) {
306+ // Arbitrary choice of poison result, intended to stick out.
307+ Value zero = spirv::ConstantOp::getZero (insertOp.getDestVectorType (),
308+ insertOp.getLoc (), rewriter);
309+ rewriter.replaceOp (insertOp, zero);
310+ } else
311+ rewriter.replaceOpWithNewOp <spirv::CompositeInsertOp>(
312+ insertOp, adaptor.getSource (), adaptor.getDest (), id.value ());
313+ } else {
314+ Value sanitizedIndex = sanitizeDynamicIndex (
315+ rewriter, insertOp.getLoc (), adaptor.getDynamicPosition ()[0 ],
316+ vector::InsertOp::kPoisonIndex );
273317 rewriter.replaceOpWithNewOp <spirv::VectorInsertDynamicOp>(
274- insertOp, insertOp.getDest (), adaptor.getSource (),
275- adaptor. getDynamicPosition ()[ 0 ]);
318+ insertOp, insertOp.getDest (), adaptor.getSource (), sanitizedIndex);
319+ }
276320 return success ();
277321 }
278322};
0 commit comments