Skip to content

Commit 3e1165f

Browse files
Revert "Revert "[Backend] Improve dot support to target FMA (#4516)"" (#3056)
Signed-off-by: Whitney Tsang <[email protected]>
1 parent e69f985 commit 3e1165f

File tree

15 files changed

+647
-369
lines changed

15 files changed

+647
-369
lines changed

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 24 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -348,12 +348,18 @@ SmallVector<Value> delinearize(RewriterBase &rewriter, Location loc,
348348
SmallVector<Value> delinearize(RewriterBase &rewriter, Location loc,
349349
Value linear, ArrayRef<unsigned> shape);
350350

351+
SmallVector<unsigned> delinearize(unsigned linear, ArrayRef<unsigned> shape,
352+
ArrayRef<unsigned> order);
353+
351354
Value linearize(RewriterBase &rewriter, Location loc, ArrayRef<Value> multiDim,
352355
ArrayRef<unsigned> shape, ArrayRef<unsigned> order);
353356

354357
Value linearize(RewriterBase &rewriter, Location loc, ArrayRef<Value> multiDim,
355358
ArrayRef<unsigned> shape);
356359

360+
size_t linearize(ArrayRef<unsigned> multiDim, ArrayRef<unsigned> shape,
361+
ArrayRef<unsigned> order);
362+
357363
Value addStringToModule(Location loc, RewriterBase &rewriter, StringRef key,
358364
StringRef content);
359365

@@ -496,6 +502,24 @@ inline Value dot(RewriterBase &rewriter, Location loc, ArrayRef<Value> offsets,
496502
return ret;
497503
}
498504

505+
/// Extend 2d shared object to 3d.
506+
///
507+
/// If the tensor has 3 dimensions, returns the original shared object.
508+
/// If the tensor shape is [M, N], returns a shared object describing shape [1, M, N].
509+
///
510+
/// This function is used to simplify processing of 2d and 3d dot operands,
511+
/// particularly in the conversion of local_load operation.
512+
///
513+
/// \param rewriter
514+
/// \param loc
515+
/// \param smemObj
516+
/// \param shape shape of a tensor represented by smemObj
517+
/// \returns shared object describing 3d tensor
518+
SharedMemoryObject
519+
getExpandedSharedMemoryObject(ConversionPatternRewriter &rewriter, Location loc,
520+
SharedMemoryObject smemObj,
521+
ArrayRef<int64_t> shape);
522+
499523
// -----------------------------------------------------------------------
500524
// Blocked layout indices
501525
// -----------------------------------------------------------------------

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -234,6 +234,12 @@ void dumpHWLayout(RankedTensorType tensorType);
234234
// Return a string representation of the layout of the tensor.
235235
std::string getLayoutStr(RankedTensorType tensorType, bool useHWPointOfView);
236236

237+
template <typename T>
238+
llvm::SmallVector<T> expandMatrixShapeWithBatch(llvm::ArrayRef<T> s);
239+
240+
llvm::SmallVector<unsigned>
241+
expandMatrixOrderWithBatch(llvm::ArrayRef<unsigned> o);
242+
237243
} // namespace gpu
238244
} // namespace triton
239245
} // namespace mlir

0 commit comments

Comments
 (0)