@@ -178,10 +178,8 @@ std::optional<DenseElementsAttr>
178178TosaReduceTransposes::transposeDenseAttribute (DenseElementsAttr input,
179179 ArrayRef<int32_t > perms) {
180180 RankedTensorType oldType = llvm::cast<RankedTensorType>(input.getType ());
181- RankedTensorType newType =
182- RankedTensorType::get (applyTOSAPermutation (oldType.getShape (), perms),
183- oldType.getElementType ());
184- size_t rank = oldType.getRank ();
181+ ArrayRef<int64_t > oldShape = oldType.getShape ();
182+ int64_t rank = oldType.getRank ();
185183
186184 // Asserted by TransposeOp verifier and TOSA disallowing tensor with dimension
187185 // 0. If not in place, something is very wrong.
@@ -190,65 +188,83 @@ TosaReduceTransposes::transposeDenseAttribute(DenseElementsAttr input,
190188 return std::nullopt ;
191189 }
192190
193- if (input.isSplat ())
191+ auto newShape = applyTOSAPermutation (oldShape, perms);
192+ RankedTensorType newType =
193+ RankedTensorType::get (newShape, oldType.getElementType ());
194+
195+ if (input.isSplat ()) {
194196 return input.reshape (newType);
197+ }
198+
199+ auto rawData = input.getRawData ();
200+ if (!rawData.data ()) {
201+ return std::nullopt ;
202+ }
195203
196204 // The algorithm is approximately as follows:
197- // input: perms, input flat array, input tensor type
198- // (1/2) determine the strides of input/output if
199- // they were strided in row-major order. (3) adjust the strides for the
200- // input to be in the same order of indices as the output is written.
201- // (4) process dimension by dimension. example: perms 2, 0, 1; input
202- // 2x3x4; output 4x2x3 for i ... 4, j ... 2, k ... 3: output[i][j][k] =
203- // input[j][k][i] output[6i + 3j + k] = input[12j + 4k + i] and we adjust
204- // input strides to be as input[i + 12j + 4k] so we may process
205- // layer-by-layer.
206-
207- // Step 1/2: Strides for input. We ignore output since row-major and can just
208- // push_back.
209-
210- SmallVector<int64_t > originalInputStrides (rank);
211- originalInputStrides[rank - 1 ] = 1 ;
212- // index with int64_t to avoid overflow
213- for (int64_t i = rank - 2 ; i >= 0 ; i--)
214- originalInputStrides[i] =
215- originalInputStrides[i + 1 ] * oldType.getDimSize (i + 1 );
216-
217- // Step 3: Transpose strides of input to be same indexing (i, j, k, ...) as
218- // output which is done in row-major order.
219-
220- SmallVector<int64_t > newInputStrides;
221- newInputStrides.reserve (rank);
222- for (int32_t v : perms)
223- newInputStrides.push_back (originalInputStrides[v]);
224-
225- // Step 4: Write out the transposed "flat array" dimension by dimension.
226-
227- auto inputArray = input.getValues <Attribute>();
228- SmallVector<std::pair<int64_t , int64_t >> boundsAndStrides;
229- for (size_t i = 0 ; i < rank; i++)
230- boundsAndStrides.push_back ({newType.getDimSize (i), newInputStrides[i]});
231-
232- SmallVector<Attribute> resultArray;
233- resultArray.reserve (inputArray.size ());
234-
235- std::function<void (int64_t ,
236- SmallVector<std::pair<int64_t , int64_t >>::const_iterator)>
237- processTransposeDim = [&](auto accumulatedIndex, auto it) {
238- if (it == boundsAndStrides.end ()) {
239- resultArray.push_back (inputArray[accumulatedIndex]);
240- return ;
241- }
242-
243- for (int64_t i = 0 ; i < it->first ; i++) {
244- int64_t j = accumulatedIndex + i * it->second ;
245- processTransposeDim (j, it + 1 );
246- }
247- };
248-
249- processTransposeDim (0 , boundsAndStrides.begin ());
250-
251- return DenseElementsAttr::get (newType, resultArray);
205+ // 1. Determine the strides of both input and output tensors in row-major
206+ // order
207+ // 2. Iterate through the output tensor linearly.
208+ // 3. For each output position, decompose the linear index into
209+ // multi-dimensional coordinates using output strides.
210+ // 4. Use the permutation to map output coordinates to input coordinates and
211+ // calculate the source linear index.
212+
213+ // Example: perms [2, 0, 1]; input 2x3x4; output 4x2x3
214+ // for output linear index 11: decompose to output[1][1][2]
215+ // using output strides [6,3,1]. Map to input coordinates using
216+ // perms: dim 0→2, dim 1→0, dim 2→1, giving source position
217+ // calculated as 1*inputStrides[2] + 1*inputStrides[0] + 2*inputStrides[1]
218+ // = 1*1 + 1*12 + 2*4 = 21
219+
220+ size_t elementSize = oldType.getElementTypeBitWidth () / 8 ;
221+ int64_t numElements = oldType.getNumElements ();
222+
223+ SmallVector<char > outputBuffer (numElements * elementSize);
224+ const char *inputPtr = rawData.data ();
225+ char *outputPtr = outputBuffer.data ();
226+
227+ auto calculateStrides = [](ArrayRef<int64_t > shape) -> SmallVector<int64_t > {
228+ int64_t rank = shape.size ();
229+ SmallVector<int64_t > strides (rank);
230+ strides[rank - 1 ] = 1 ;
231+ for (int64_t i = rank - 2 ; i >= 0 ; --i) {
232+ strides[i] = strides[i + 1 ] * shape[i + 1 ];
233+ }
234+ return strides;
235+ };
236+
237+ // Calculate strides for both input and output tensors
238+ SmallVector<int64_t > inputStrides = calculateStrides (oldShape);
239+ SmallVector<int64_t > outputStrides = calculateStrides (newShape);
240+
241+ auto mapCoordinates = [&](int64_t destLinearIndex) -> int64_t {
242+ int64_t tempDestIndex = destLinearIndex;
243+ int64_t sourceLinearIndex = 0 ;
244+
245+ // Decompose linear destination index into multi-dimensional
246+ // coordinates dividing by output strides.
247+ // Simultaneously map these coordinates through the permutation
248+ // to calculate the corresponding source linear index.
249+ for (auto j : llvm::seq<int64_t >(rank)) {
250+ int64_t destCoord = tempDestIndex / outputStrides[j];
251+ tempDestIndex %= outputStrides[j];
252+ sourceLinearIndex += destCoord * inputStrides[perms[j]];
253+ }
254+
255+ return sourceLinearIndex;
256+ };
257+
258+ for (auto destLinearIndex : llvm::seq<int64_t >(numElements)) {
259+ int64_t sourceLinearIndex = mapCoordinates (destLinearIndex);
260+
261+ // Copy the element from source to destination using type-agnostic byte
262+ // copying.
263+ std::memcpy (outputPtr + destLinearIndex * elementSize,
264+ inputPtr + sourceLinearIndex * elementSize, elementSize);
265+ }
266+
267+ return DenseElementsAttr::getFromRawBuffer (newType, outputBuffer);
252268}
253269
// The SetVector should only contain ConstOp, ReshapeOp, TransposeOp
0 commit comments