5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
@@ -39,6 +39,10 @@ class FuncOp;
/// Collect a set of patterns to rewrite GlobalIdOp op within the GPU dialect.
void populateGpuGlobalIdPatterns(RewritePatternSet &patterns);

/// Collect a set of patterns to rewrite SubgroupIdOp op within the GPU
/// dialect.
void populateGpuSubgroupIdPatterns(RewritePatternSet &patterns);

/// Collect a set of patterns to rewrite shuffle ops within the GPU dialect.
void populateGpuShufflePatterns(RewritePatternSet &patterns);

@@ -88,6 +92,7 @@ inline void populateGpuRewritePatterns(RewritePatternSet &patterns) {
populateGpuAllReducePatterns(patterns);
populateGpuGlobalIdPatterns(patterns);
populateGpuShufflePatterns(patterns);
populateGpuSubgroupIdPatterns(patterns);
Contributor:
Let's make sure that this doesn't end up in the SPIR-V lowerings, which seem to have an alternate approach to this op

Member Author:
I see. SPIR-V links the get_sub_group_id function to get the subgroup id, so I think we should just remove this line inside populateGpuRewritePatterns.

Side note: SPIR-V uses the same calculation method to compute subgroup_id.

Member Author:
I eventually stripped it from this function.
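
Since the pattern was dropped from populateGpuRewritePatterns, a downstream pipeline that still wants this rewrite has to register it explicitly. A minimal sketch, assuming the usual greedy pattern driver (the wrapper function below is illustrative, not part of this patch):

#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical helper: apply only the subgroup_id rewrite to `root`.
static LogicalResult rewriteSubgroupIds(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  populateGpuSubgroupIdPatterns(patterns);
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}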

}

namespace gpu {
1 change: 1 addition & 0 deletions mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -40,6 +40,7 @@ add_mlir_dialect_library(MLIRGPUTransforms
Transforms/ROCDLAttachTarget.cpp
Transforms/ShuffleRewriter.cpp
Transforms/SPIRVAttachTarget.cpp
Transforms/SubgroupIdRewriter.cpp
Transforms/SubgroupReduceLowering.cpp

OBJECT
82 changes: 82 additions & 0 deletions mlir/lib/Dialect/GPU/Transforms/SubgroupIdRewriter.cpp
@@ -0,0 +1,82 @@
//===- SubgroupIdRewriter.cpp - Implementation of SubgroupId rewriting ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements in-dialect rewriting of the gpu.subgroup_id op for archs
// where:
// subgroup_id = (tid.x + dim.x * (tid.y + dim.y * tid.z)) / subgroup_size
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;

namespace {
struct GpuSubgroupIdRewriter final : OpRewritePattern<gpu::SubgroupIdOp> {
using OpRewritePattern<gpu::SubgroupIdOp>::OpRewritePattern;

LogicalResult matchAndRewrite(gpu::SubgroupIdOp op,
PatternRewriter &rewriter) const override {
// Calculation of the thread's subgroup identifier.
//
// The process involves mapping the thread's 3D identifier within its
// block (tid.x, tid.y, tid.z) to a 1D linear index.
// This linearization assumes a layout where the x-dimension (dim.x)
// varies most rapidly (i.e., it is the innermost dimension).
//
// The formula for the linearized thread index is:
// L = tid.x + dim.x * (tid.y + (dim.y * tid.z))
//
// Subsequently, the range of linearized indices [0, N_threads-1] is
// divided into consecutive, non-overlapping segments, each representing
// a subgroup of size 'subgroup_size'.
//
// Example Partitioning (N = subgroup_size):
// | Subgroup 0 | Subgroup 1 | Subgroup 2 | ... |
// | Indices 0..N-1 | Indices N..2N-1 | Indices 2N..3N-1| ... |
//
// The subgroup identifier is obtained via integer division of the
// linearized thread index by the predefined 'subgroup_size'.
//
// subgroup_id = floor( L / subgroup_size )
// = (tid.x + dim.x * (tid.y + dim.y * tid.z)) /
// subgroup_size
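//
// An illustrative example (numbers chosen arbitrarily): with
// dim = (32, 4, 2), subgroup_size = 32, and thread
// (tid.x, tid.y, tid.z) = (5, 2, 1):
//   L = 5 + 32 * (2 + 4 * 1) = 197, so subgroup_id = 197 / 32 = 6.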

auto loc = op->getLoc();

Value dimX = rewriter.create<gpu::BlockDimOp>(loc, gpu::Dimension::x);
Value dimY = rewriter.create<gpu::BlockDimOp>(loc, gpu::Dimension::y);
Value tidX = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x);
Value tidY = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::y);
Value tidZ = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::z);

Value dimYxIdZ = rewriter.create<index::MulOp>(loc, dimY, tidZ);
Value dimYxIdZPlusIdY = rewriter.create<index::AddOp>(loc, dimYxIdZ, tidY);
Value dimYxIdZPlusIdYTimesDimX =
rewriter.create<index::MulOp>(loc, dimX, dimYxIdZPlusIdY);
Contributor:
I think arith:: over index:: is fine here - none of these values are at the point where the issues that caused the index dialect to come into existence become a problem

Contributor:
... Hey, why'd this get landed with index::?
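
For reference, a sketch of what the arith:: form of the same sequence might look like (assuming the Arith dialect is loaded and "mlir/Dialect/Arith/IR/Arith.h" is included; variable names mirror the index:: version above, and this is not what the patch landed with):

// Sketch of the arith:: equivalent; arith ops accept index-typed operands.
Value dimYxIdZ = rewriter.create<arith::MulIOp>(loc, dimY, tidZ);
Value dimYxIdZPlusIdY = rewriter.create<arith::AddIOp>(loc, dimYxIdZ, tidY);
Value dimYxIdZPlusIdYTimesDimX =
    rewriter.create<arith::MulIOp>(loc, dimX, dimYxIdZPlusIdY);
Value linearIdx =
    rewriter.create<arith::AddIOp>(loc, tidX, dimYxIdZPlusIdYTimesDimX);
Value subgroupSize = rewriter.create<gpu::SubgroupSizeOp>(
    loc, rewriter.getIndexType(), /*upper_bound=*/nullptr);
Value subgroupId =
    rewriter.create<arith::DivUIOp>(loc, linearIdx, subgroupSize);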

Value IdXPlusDimYxIdZPlusIdYTimesDimX =
rewriter.create<index::AddOp>(loc, tidX, dimYxIdZPlusIdYTimesDimX);
Value subgroupSize = rewriter.create<gpu::SubgroupSizeOp>(
loc, rewriter.getIndexType(), /*upper_bound = */ nullptr);
Value subgroupIdOp = rewriter.create<index::DivUOp>(
loc, IdXPlusDimYxIdZPlusIdYTimesDimX, subgroupSize);
rewriter.replaceOp(op, {subgroupIdOp});
return success();
}
};

} // namespace

void mlir::populateGpuSubgroupIdPatterns(RewritePatternSet &patterns) {
patterns.add<GpuSubgroupIdRewriter>(patterns.getContext());
}
26 changes: 26 additions & 0 deletions mlir/test/Dialect/GPU/subgroupId-rewrite.mlir
@@ -0,0 +1,26 @@
// RUN: mlir-opt --test-gpu-rewrite -split-input-file %s | FileCheck %s

module {
Member:
This contains both an explicit module and an implicit one -- I don't think we need both. Can we drop either one?

Member Author:
Updated

// CHECK-LABEL: func.func @subgroupId
// CHECK-SAME: (%[[SZ:.*]]: index, %[[MEM:.*]]: memref<index, 1>) {
func.func @subgroupId(%sz : index, %mem: memref<index, 1>) {
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %sz, %grid_y = %sz, %grid_z = %sz)
threads(%tx, %ty, %tz) in (%block_x = %sz, %block_y = %sz, %block_z = %sz) {
// CHECK: %[[DIMX:.*]] = gpu.block_dim x
// CHECK-NEXT: %[[DIMY:.*]] = gpu.block_dim y
// CHECK-NEXT: %[[TIDX:.*]] = gpu.thread_id x
// CHECK-NEXT: %[[TIDY:.*]] = gpu.thread_id y
// CHECK-NEXT: %[[TIDZ:.*]] = gpu.thread_id z
// CHECK-NEXT: %[[T0:.*]] = index.mul %[[DIMY]], %[[TIDZ]]
// CHECK-NEXT: %[[T1:.*]] = index.add %[[T0]], %[[TIDY]]
// CHECK-NEXT: %[[T2:.*]] = index.mul %[[DIMX]], %[[T1]]
// CHECK-NEXT: %[[T3:.*]] = index.add %[[TIDX]], %[[T2]]
// CHECK-NEXT: %[[T4:.*]] = gpu.subgroup_size : index
// CHECK-NEXT: %[[T5:.*]] = index.divu %[[T3]], %[[T4]]
%idz = gpu.subgroup_id : index
memref.store %idz, %mem[] : memref<index, 1>
gpu.terminator
}
return
}
}