Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 6 additions & 9 deletions mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ using namespace mlir::nvgpu;

#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.cpp.inc"

void nvgpu::NVGPUDialect::initialize() {
void NVGPUDialect::initialize() {
addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/NVGPU/IR/NVGPUTypeDefs.cpp.inc"
Expand All @@ -42,7 +42,7 @@ void nvgpu::NVGPUDialect::initialize() {
>();
}

bool nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(Attribute memorySpace) {
bool NVGPUDialect::isSharedMemoryAddressSpace(Attribute memorySpace) {
if (!memorySpace)
return false;
if (auto intAttr = llvm::dyn_cast<IntegerAttr>(memorySpace))
Expand All @@ -52,7 +52,7 @@ bool nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(Attribute memorySpace) {
return false;
}

bool nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(MemRefType type) {
bool NVGPUDialect::hasSharedMemoryAddressSpace(MemRefType type) {
Attribute memorySpace = type.getMemorySpace();
return isSharedMemoryAddressSpace(memorySpace);
}
Expand Down Expand Up @@ -140,7 +140,6 @@ static LogicalResult verifyMmaSyncOp(Operation *op,
TypedValue<VectorType> matrixC,
const std::array<int64_t, 3> &mmaShape,
bool tf32Enabled, bool sparse = false) {

// The verification for mma.sync covering various shapes and data types is
// based on the fundamental tensor core shape.

Expand Down Expand Up @@ -292,7 +291,6 @@ LogicalResult MmaSparseSyncOp::verify() {
// NVGPU_LdMatrixOp
//===----------------------------------------------------------------------===//
LogicalResult LdMatrixOp::verify() {

// ldmatrix reads data from source in shared memory
auto srcMemref = llvm::cast<MemRefType>(getSrcMemref().getType());

Expand Down Expand Up @@ -345,7 +343,7 @@ LogicalResult LdMatrixOp::verify() {
// NVGPU_TmaAsyncLoadOp
//===----------------------------------------------------------------------===//

unsigned getSwizzleBytes(TensorMapSwizzleKind kind) {
static unsigned getSwizzleBytes(TensorMapSwizzleKind kind) {
switch (kind) {
case TensorMapSwizzleKind::SWIZZLE_32B:
return 32;
Expand All @@ -359,7 +357,7 @@ unsigned getSwizzleBytes(TensorMapSwizzleKind kind) {
}

std::optional<InFlightDiagnostic> verifyTmaDescriptorWithMemref(
Operation *op, nvgpu::TensorMapDescriptorType descType,
Operation *op, TensorMapDescriptorType descType,
std::optional<MemRefType> memrefType = std::nullopt) {
MemRefType descMemref = descType.getTensor();
// Limitation
Expand Down Expand Up @@ -655,8 +653,7 @@ LogicalResult WarpgroupMmaStoreOp::verify() {
//===----------------------------------------------------------------------===//

LogicalResult WarpgroupMmaInitAccumulatorOp::verify() {

nvgpu::WarpgroupAccumulatorType accType = getMatrixC().getType();
WarpgroupAccumulatorType accType = getMatrixC().getType();
int64_t sizeM = accType.getFragmented().getDimSize(0);
int64_t sizeN = accType.getFragmented().getDimSize(1);
Type elemType = accType.getFragmented().getElementType();
Expand Down
Loading