@@ -152,15 +152,15 @@ translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
152152}
153153
154154// Compute the product of sizes in the range [lo, hi) from the sizes array.
155+ // Note: all sizes are i64.
155156static Value getProductOfSizes (ConversionPatternRewriter &rewriter,
156157 Location loc, ArrayRef<OpFoldResult> sizes,
157158 size_t lo, size_t hi) {
158- Type indexTy = rewriter. getIndexType ();
159- Value product = arith::ConstantIndexOp ::create (rewriter, loc, 1 );
159+ Value product =
160+ arith::ConstantIntOp ::create (rewriter, loc, rewriter. getI64Type () , 1 );
160161 for (size_t idx = lo; idx < hi; idx++) {
161162 OpFoldResult ofr = sizes[idx];
162163 Value sizeVal = getValueOrCreateConstantIntOp (rewriter, loc, ofr);
163- sizeVal = getValueOrCreateCastToIndexLike (rewriter, loc, indexTy, sizeVal);
164164 product = rewriter.createOrFold <arith::MulIOp>(loc, product, sizeVal);
165165 }
166166 return product;
@@ -233,6 +233,8 @@ class CreateNdDescToXeVMPattern
233233 // Generate compute chain for height (product of sizes of all but the last
234234 // dimension).
235235 baseShapeH = getProductOfSizes (rewriter, loc, mixedSizes, 0 , srcRank - 1 );
236+ baseShapeH = getValueOrCreateCastToIndexLike (rewriter, loc, payloadElemTy,
237+ baseShapeH);
236238 }
237239 if (sourceMemrefTy) {
238240 // Cast index to i64.
0 commit comments