@@ -46,6 +46,69 @@ static void setInsertionPointToStart(OpBuilder &builder, Value val) {
4646 }
4747}
4848
49+ OpFoldResult computeMemRefSpan (Value memref, OpBuilder &builder) {
50+ Location loc = memref.getLoc ();
51+ MemRefType type = cast<MemRefType>(memref.getType ());
52+ ArrayRef<int64_t > shape = type.getShape ();
53+
54+ // Check for empty memref
55+ if (type.hasStaticShape () &&
56+ llvm::any_of (shape, [](int64_t dim) { return dim == 0 ; })) {
57+ return builder.getIndexAttr (0 );
58+ }
59+
60+ // Get strides of the memref
61+ SmallVector<int64_t , 4 > strides;
62+ int64_t offset;
63+ if (failed (type.getStridesAndOffset (strides, offset))) {
64+ // Cannot extract strides, return a dynamic value
65+ return Value ();
66+ }
67+
68+ // Static case: compute at compile time if possible
69+ if (type.hasStaticShape ()) {
70+ int64_t span = 0 ;
71+ for (unsigned i = 0 ; i < type.getRank (); ++i) {
72+ span += (shape[i] - 1 ) * strides[i];
73+ }
74+ return builder.getIndexAttr (span);
75+ }
76+
77+ // Dynamic case: emit IR to compute at runtime
78+ Value result = builder.create <arith::ConstantIndexOp>(loc, 0 );
79+
80+ for (unsigned i = 0 ; i < type.getRank (); ++i) {
81+ // Get dimension size
82+ Value dimSize;
83+ if (shape[i] == ShapedType::kDynamic ) {
84+ dimSize = builder.create <memref::DimOp>(loc, memref, i);
85+ } else {
86+ dimSize = builder.create <arith::ConstantIndexOp>(loc, shape[i]);
87+ }
88+
89+ // Compute (dim - 1)
90+ Value one = builder.create <arith::ConstantIndexOp>(loc, 1 );
91+ Value dimMinusOne = builder.create <arith::SubIOp>(loc, dimSize, one);
92+
93+ // Get stride
94+ Value stride;
95+ if (strides[i] == ShapedType::kDynamicStrideOrOffset ) {
96+ // For dynamic strides, need to extract from memref descriptor
97+ // This would require runtime support, possibly using extractStride
98+ // As a placeholder, return a dynamic value
99+ return Value ();
100+ } else {
101+ stride = builder.create <arith::ConstantIndexOp>(loc, strides[i]);
102+ }
103+
104+ // Add (dim - 1) * stride to result
105+ Value term = builder.create <arith::MulIOp>(loc, dimMinusOne, stride);
106+ result = builder.create <arith::AddIOp>(loc, result, term);
107+ }
108+
109+ return result;
110+ }
111+
49112static std::tuple<Value, OpFoldResult, SmallVector<OpFoldResult>, OpFoldResult,
50113 OpFoldResult>
51114getFlatOffsetAndStrides (OpBuilder &rewriter, Location loc, Value source,
@@ -102,8 +165,9 @@ getFlatOffsetAndStrides(OpBuilder &rewriter, Location loc, Value source,
102165 affine::makeComposedFoldedAffineApply (rewriter, loc, expr, values);
103166
104167 // Compute collapsed size: (the outmost stride * outmost dimension).
105- SmallVector<OpFoldResult> ops{origStrides.front (), outmostDim};
106- OpFoldResult collapsedSize = affine::computeProduct (loc, rewriter, ops);
168+ // SmallVector<OpFoldResult> ops{origStrides.front(), outmostDim};
169+ // OpFoldResult collapsedSize = affine::computeProduct(loc, rewriter, ops);
170+ OpFoldResult collapsedSize = computeMemRefSpan (source, rewriter);
107171
108172 return {newExtractStridedMetadata.getBaseBuffer (), linearizedIndex,
109173 origStrides, origOffset, collapsedSize};
0 commit comments