Skip to content

Commit 52cf1ae

Browse files
authored
[Backend] Improve dot support to target FMA (#4516)
This PR: - Refactors FMA dot implementation - Supports dot3d in FMA path - Fixes several issues in operand offset computation - Enables small dot operands for AMD backend
1 parent 976c4e4 commit 52cf1ae

File tree

13 files changed

+573
-294
lines changed

13 files changed

+573
-294
lines changed

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,12 +335,18 @@ SmallVector<Value> delinearize(RewriterBase &rewriter, Location loc,
335335
SmallVector<Value> delinearize(RewriterBase &rewriter, Location loc,
336336
Value linear, ArrayRef<unsigned> shape);
337337

338+
SmallVector<unsigned> delinearize(unsigned linear, ArrayRef<unsigned> shape,
339+
ArrayRef<unsigned> order);
340+
338341
Value linearize(RewriterBase &rewriter, Location loc, ArrayRef<Value> multiDim,
339342
ArrayRef<unsigned> shape, ArrayRef<unsigned> order);
340343

341344
Value linearize(RewriterBase &rewriter, Location loc, ArrayRef<Value> multiDim,
342345
ArrayRef<unsigned> shape);
343346

347+
size_t linearize(ArrayRef<unsigned> multiDim, ArrayRef<unsigned> shape,
348+
ArrayRef<unsigned> order);
349+
344350
Value addStringToModule(Location loc, RewriterBase &rewriter, StringRef key,
345351
StringRef content);
346352

@@ -495,6 +501,24 @@ inline Value dot(RewriterBase &rewriter, Location loc, ArrayRef<Value> offsets,
495501
return ret;
496502
}
497503

504+
/// Extend 2d shared object to 3d.
505+
///
506+
/// If tensor has 3 dimensions, returns original shared object.
507+
/// If tensor shape is [M, N], return shared object describing shape [1, M, N]
508+
///
509+
/// This Function is used to simplify processing of 2d and 3d dot operands,
510+
/// particularly in the conversion of local_load operation.
511+
///
512+
/// \param rewriter
513+
/// \param loc
514+
/// \param smemObj
515+
/// \param shape shape of a tensor represented by smemObj
516+
/// \returns shared object describing 3d tensor
517+
SharedMemoryObject
518+
getExpandedSharedMemoryObject(ConversionPatternRewriter &rewriter, Location loc,
519+
SharedMemoryObject smemObj,
520+
ArrayRef<int64_t> shape);
521+
498522
// -----------------------------------------------------------------------
499523
// Blocked layout indices
500524
// -----------------------------------------------------------------------

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,12 @@ void dumpHWLayout(RankedTensorType tensorType);
234234
// Return a string representation of the layout of the tensor.
235235
std::string getLayoutStr(RankedTensorType tensorType, bool useHWPointOfView);
236236

237+
template <typename T>
238+
llvm::SmallVector<T> expandMatrixShapeWithBatch(llvm::ArrayRef<T> s);
239+
240+
llvm::SmallVector<unsigned>
241+
expandMatrixOrderWithBatch(llvm::ArrayRef<unsigned> o);
242+
237243
} // namespace gpu
238244
} // namespace triton
239245
} // namespace mlir

0 commit comments

Comments
 (0)