
Conversation

avik-pal (Collaborator) commented Sep 25, 2025

The Triton kernel below (the vector_add.py module that the Julia snippet imports):

import triton
import triton.language as tl


@triton.jit
def add_kernel(
    x_ptr,  # *Pointer* to first input vector.
    y_ptr,  # *Pointer* to second input vector.
    output_ptr,  # *Pointer* to output vector.
    n_elements,  # Size of the vector.
    BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
    # NOTE: `constexpr` so it can be used as a shape value.
):
    # There are multiple 'programs' processing different data. We identify which program
    # we are here:
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    # This program will process inputs that are offset from the initial data.
    # For instance, if you had a vector of length 256 and block_size of 64, the programs
    # would each access the elements [0:64, 64:128, 128:192, 192:256].
    # Note that offsets is a list of pointers:
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Create a mask to guard memory operations against out-of-bounds accesses.
    mask = offsets < n_elements
    # Load x and y from DRAM, masking out any extra elements in case the input is not a
    # multiple of the block size.
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    # Write x + y back to DRAM.
    tl.store(output_ptr + offsets, output, mask=mask)

Julia driver that loads the kernel through PythonCall and traces it with Reactant:

using PythonCall, Reactant

pyimport("sys").path.append(@__DIR__)
kernel = pyimport("vector_add").add_kernel

x = Reactant.to_rarray(rand(Float32, 1024));
y = Reactant.to_rarray(rand(Float32, 1024));
out = Reactant.to_rarray(zeros(Float32, 1024));

@code_hlo kernel(
    x,
    y,
    out,
    length(x),
    64;
    grid=cld(length(x), 64),
    num_warps=1,
    num_stages=3,
    hints=Dict(1 => 16),
)
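
A minimal sketch of actually running the kernel rather than only inspecting the IR, assuming Reactant's @jit macro accepts the same call signature that @code_hlo traces above and that the aliased outputs update `out` in place:

# Sketch only: assumes the grid/num_warps/num_stages/hints keyword interface
# works identically under @jit as under @code_hlo.
@jit kernel(
    x,
    y,
    out,
    length(x),
    64;
    grid=cld(length(x), 64),
    num_warps=1,
    num_stages=3,
    hints=Dict(1 => 16),
)

# Host-side check; assumes Array(::ConcreteRArray) copies the buffer back to the CPU.
@assert Array(out) ≈ Array(x) .+ Array(y)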

avik-pal force-pushed the ap/triton_integration branch 2 times, most recently from a7ece19 to f776758 on September 27, 2025 13:17
avik-pal (Collaborator, Author) commented Sep 27, 2025

IR generated by the @code_hlo call above:

module @reactant_JITFunc... attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
  module @tt_module_0 {
    tt.func @add_kernel_call_e72661bb113efd0f(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>) attributes {noinline = false} {
      %0 = tt.get_program_id x : i32
      %c64_i32 = arith.constant 64 : i32
      %c64_i32_0 = arith.constant 64 : i32
      %1 = arith.extsi %0 : i32 to i64
      %2 = arith.extsi %c64_i32_0 : i32 to i64
      %3 = arith.muli %1, %2 : i64
      %c2147483647_i64 = arith.constant 2147483647 : i64
      %c-2147483648_i64 = arith.constant -2147483648 : i64
      %4 = arith.cmpi sle, %3, %c2147483647_i64 : i64
      %5 = arith.cmpi sge, %3, %c-2147483648_i64 : i64
      %6 = arith.andi %4, %5 : i1
      %7 = arith.muli %0, %c64_i32_0 : i32
      %8 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
      %9 = tt.splat %7 : i32 -> tensor<64xi32>
      %10 = arith.extsi %9 : tensor<64xi32> to tensor<64xi64>
      %11 = arith.extsi %8 : tensor<64xi32> to tensor<64xi64>
      %12 = arith.addi %10, %11 : tensor<64xi64>
      %c2147483647_i64_1 = arith.constant 2147483647 : i64
      %c-2147483648_i64_2 = arith.constant -2147483648 : i64
      %cst = arith.constant dense<2147483647> : tensor<64xi64>
      %13 = arith.cmpi sle, %12, %cst : tensor<64xi64>
      %cst_3 = arith.constant dense<-2147483648> : tensor<64xi64>
      %14 = arith.cmpi sge, %12, %cst_3 : tensor<64xi64>
      %15 = arith.andi %13, %14 : tensor<64xi1>
      %16 = arith.addi %9, %8 : tensor<64xi32>
      %c1024_i32 = arith.constant 1024 : i32
      %cst_4 = arith.constant dense<1024> : tensor<64xi32>
      %17 = arith.cmpi slt, %16, %cst_4 : tensor<64xi32>
      %18 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
      %19 = tt.addptr %18, %16 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
      %20 = tt.load %19, %17 : tensor<64x!tt.ptr<f32>>
      %21 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
      %22 = tt.addptr %21, %16 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
      %23 = tt.load %22, %17 : tensor<64x!tt.ptr<f32>>
      %24 = arith.addf %20, %23 : tensor<64xf32>
      %25 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
      %26 = tt.addptr %25, %16 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
      tt.store %26, %24, %17 : tensor<64x!tt.ptr<f32>>
      tt.return
    }
  }
  func.func @main(%arg0: tensor<1024xf32> {tf.aliasing_output = 0 : i32}, %arg1: tensor<1024xf32> {tf.aliasing_output = 1 : i32}, %arg2: tensor<1024xf32> {tf.aliasing_output = 2 : i32}) -> (tensor<1024xf32>, tensor<1024xf32>, tensor<1024xf32>) {
    %0 = stablehlo.transpose %arg0, dims = [0] : (tensor<1024xf32>) -> tensor<1024xf32>
    %1 = stablehlo.transpose %arg1, dims = [0] : (tensor<1024xf32>) -> tensor<1024xf32>
    %2 = stablehlo.transpose %arg2, dims = [0] : (tensor<1024xf32>) -> tensor<1024xf32>
    %c = stablehlo.constant dense<16> : tensor<i64>
    %c_0 = stablehlo.constant dense<1> : tensor<i64>
    %c_1 = stablehlo.constant dense<1> : tensor<i64>
    %c_2 = stablehlo.constant dense<0> : tensor<i64>
    enzymexla.triton_call @tt_module_0::@add_kernel_call_e72661bb113efd0f blocks in(%c, %c_0, %c_1) shmem = %c_2 (%0, %1, %2) : (tensor<1024xf32>, tensor<1024xf32>, tensor<1024xf32>) -> ()
    %3 = stablehlo.transpose %0, dims = [0] : (tensor<1024xf32>) -> tensor<1024xf32>
    %4 = stablehlo.transpose %1, dims = [0] : (tensor<1024xf32>) -> tensor<1024xf32>
    %5 = stablehlo.transpose %2, dims = [0] : (tensor<1024xf32>) -> tensor<1024xf32>
    return %3, %4, %5 : tensor<1024xf32>, tensor<1024xf32>, tensor<1024xf32>
  }
}
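
For reference, the dense<16> constant passed to blocks in(...) is just the launch grid computed on the Julia side: cld(length(x), 64) for 1024 elements, with the y and z dimensions left at 1.

julia> cld(1024, 64)  # x-dimension of the launch grid
16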

avik-pal force-pushed the ap/triton_integration branch 2 times, most recently from 95598f9 to 7f0afd8 on September 28, 2025 16:20
avik-pal (Collaborator, Author) commented:

Updated IR; the kernel is now wrapped in a triton_ext.module and invoked through triton_ext.call instead of enzymexla.triton_call:

module @reactant_JITFunc... attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
  triton_ext.module @add_kernel_tt_module_e72661bb113efd0f {
    builtin.module @add_kernel_module_e72661bb113efd0f {
      tt.func private @add_kernel_call_e72661bb113efd0f(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>) attributes {enzymexla.memory_effects = ["read", "write"], noinline = false} {
        %0 = tt.get_program_id x : i32
        %c64_i32 = arith.constant 64 : i32
        %c64_i32_0 = arith.constant 64 : i32
        %1 = arith.extsi %0 : i32 to i64
        %2 = arith.extsi %c64_i32_0 : i32 to i64
        %3 = arith.muli %1, %2 : i64
        %c2147483647_i64 = arith.constant 2147483647 : i64
        %c-2147483648_i64 = arith.constant -2147483648 : i64
        %4 = arith.cmpi sle, %3, %c2147483647_i64 : i64
        %5 = arith.cmpi sge, %3, %c-2147483648_i64 : i64
        %6 = arith.andi %4, %5 : i1
        %7 = arith.muli %0, %c64_i32_0 : i32
        %8 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
        %9 = tt.splat %7 : i32 -> tensor<64xi32>
        %10 = arith.extsi %9 : tensor<64xi32> to tensor<64xi64>
        %11 = arith.extsi %8 : tensor<64xi32> to tensor<64xi64>
        %12 = arith.addi %10, %11 : tensor<64xi64>
        %c2147483647_i64_1 = arith.constant 2147483647 : i64
        %c-2147483648_i64_2 = arith.constant -2147483648 : i64
        %cst = arith.constant dense<2147483647> : tensor<64xi64>
        %13 = arith.cmpi sle, %12, %cst : tensor<64xi64>
        %cst_3 = arith.constant dense<-2147483648> : tensor<64xi64>
        %14 = arith.cmpi sge, %12, %cst_3 : tensor<64xi64>
        %15 = arith.andi %13, %14 : tensor<64xi1>
        %16 = arith.addi %9, %8 : tensor<64xi32>
        %c1024_i32 = arith.constant 1024 : i32
        %cst_4 = arith.constant dense<1024> : tensor<64xi32>
        %17 = arith.cmpi slt, %16, %cst_4 : tensor<64xi32>
        %18 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
        %19 = tt.addptr %18, %16 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
        %20 = tt.load %19, %17 : tensor<64x!tt.ptr<f32>>
        %21 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
        %22 = tt.addptr %21, %16 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
        %23 = tt.load %22, %17 : tensor<64x!tt.ptr<f32>>
        %24 = arith.addf %20, %23 : tensor<64xf32>
        %25 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
        %26 = tt.addptr %25, %16 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
        tt.store %26, %24, %17 : tensor<64x!tt.ptr<f32>>
        tt.return
      }
    }
  }
  func.func @main(%arg0: tensor<1024xf32> {tf.aliasing_output = 0 : i32}, %arg1: tensor<1024xf32> {tf.aliasing_output = 1 : i32}, %arg2: tensor<1024xf32> {tf.aliasing_output = 2 : i32}) -> (tensor<1024xf32>, tensor<1024xf32>, tensor<1024xf32>) attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
    %0 = stablehlo.transpose %arg0, dims = [0] : (tensor<1024xf32>) -> tensor<1024xf32>
    %1 = stablehlo.transpose %arg1, dims = [0] : (tensor<1024xf32>) -> tensor<1024xf32>
    %2 = stablehlo.transpose %arg2, dims = [0] : (tensor<1024xf32>) -> tensor<1024xf32>
    %c = stablehlo.constant dense<16> : tensor<i64>
    %c_0 = stablehlo.constant dense<1> : tensor<i64>
    %c_1 = stablehlo.constant dense<1> : tensor<i64>
    %c_2 = stablehlo.constant dense<0> : tensor<i64>
    triton_ext.call @add_kernel_tt_module_e72661bb113efd0f::@add_kernel_module_e72661bb113efd0f::@add_kernel_call_e72661bb113efd0f blocks in(%c, %c_0, %c_1) shmem = %c_2 (%0, %1, %2) : (tensor<1024xf32>, tensor<1024xf32>, tensor<1024xf32>) -> ()
    %3 = stablehlo.transpose %0, dims = [0] : (tensor<1024xf32>) -> tensor<1024xf32>
    %4 = stablehlo.transpose %1, dims = [0] : (tensor<1024xf32>) -> tensor<1024xf32>
    %5 = stablehlo.transpose %2, dims = [0] : (tensor<1024xf32>) -> tensor<1024xf32>
    return %3, %4, %5 : tensor<1024xf32>, tensor<1024xf32>, tensor<1024xf32>
  }
}

avik-pal force-pushed the ap/triton_integration branch from 7f0afd8 to 4876110 on September 29, 2025 20:02
avik-pal (Collaborator, Author) commented:

IR after the kernel body has been lowered to the LLVM/NVVM dialect:

module @reactant_JITFunc... attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
  triton_ext.module @add_kernel_tt_module_e72661bb113efd0f {
    builtin.module @add_kernel_module_e72661bb113efd0f attributes {ttg.global_scratch_memory_alignment = 1 : i32, ttg.global_scratch_memory_size = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 0 : i32, ttg.target = "cuda:120", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} {
      llvm.mlir.global external @global_smem() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
      llvm.func @add_kernel_call_e72661bb113efd0f(%arg0: !llvm.ptr<1>, %arg1: !llvm.ptr<1>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>) attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"], noinline = false, nvvm.kernel = 1 : ui1, nvvm.reqntid = array<i32: 32>, ttg.global_scratch_memory_alignment = 1 : i32, ttg.global_scratch_memory_size = 0 : i32} {
        %0 = llvm.mlir.undef : vector<1xf32>
        %1 = llvm.mlir.constant(0 : i32) : i32
        %2 = llvm.mlir.constant(32 : i32) : i32
        %3 = llvm.mlir.constant(31 : i32) : i32
        %4 = llvm.mlir.constant(0 : index) : i32
        %5 = llvm.mlir.constant(1024 : i32) : i32
        %6 = llvm.mlir.constant(64 : i32) : i32
        %7 = llvm.call_intrinsic "llvm.nvvm.read.ptx.sreg.ctaid.x"() : () -> i32
        %8 = llvm.mul %7, %6 : i32
        %9 = nvvm.read.ptx.sreg.tid.x : i32
        %10 = llvm.and %9, %3 : i32
        %11 = llvm.shl %10, %1 : i32
        %12 = llvm.or %1, %11 : i32
        %13 = llvm.or %12, %1 : i32
        %14 = llvm.and %13, %3 : i32
        %15 = llvm.lshr %14, %1 : i32
        %16 = llvm.xor %1, %15 : i32
        %17 = llvm.xor %1, %16 : i32
        %18 = llvm.xor %17, %1 : i32
        %19 = llvm.xor %17, %2 : i32
        %20 = llvm.add %18, %4 : i32
        %21 = llvm.add %19, %4 : i32
        %22 = llvm.add %8, %20 : i32
        %23 = llvm.add %8, %21 : i32
        %24 = llvm.icmp "slt" %22, %5 : i32
        %25 = llvm.icmp "slt" %23, %5 : i32
        %26 = llvm.getelementptr %arg0[%22] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f32
        %27 = llvm.getelementptr %arg0[%23] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f32
        %28 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "mov.u32 $0, 0x0;\0A\09@$2 ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l,b" %26, %24 : (!llvm.ptr<1>, i1) -> i32
        %29 = llvm.bitcast %28 : i32 to vector<1xf32>
        %30 = llvm.extractelement %29[%4 : i32] : vector<1xf32>
        %31 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "mov.u32 $0, 0x0;\0A\09@$2 ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l,b" %27, %25 : (!llvm.ptr<1>, i1) -> i32
        %32 = llvm.bitcast %31 : i32 to vector<1xf32>
        %33 = llvm.extractelement %32[%4 : i32] : vector<1xf32>
        %34 = llvm.getelementptr %arg1[%22] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f32
        %35 = llvm.getelementptr %arg1[%23] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f32
        %36 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "mov.u32 $0, 0x0;\0A\09@$2 ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l,b" %34, %24 : (!llvm.ptr<1>, i1) -> i32
        %37 = llvm.bitcast %36 : i32 to vector<1xf32>
        %38 = llvm.extractelement %37[%4 : i32] : vector<1xf32>
        %39 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "mov.u32 $0, 0x0;\0A\09@$2 ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l,b" %35, %25 : (!llvm.ptr<1>, i1) -> i32
        %40 = llvm.bitcast %39 : i32 to vector<1xf32>
        %41 = llvm.extractelement %40[%4 : i32] : vector<1xf32>
        %42 = llvm.fadd %30, %38 : f32
        %43 = llvm.fadd %33, %41 : f32
        %44 = llvm.getelementptr %arg2[%22] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f32
        %45 = llvm.getelementptr %arg2[%23] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f32
        %46 = llvm.insertelement %42, %0[%1 : i32] : vector<1xf32>
        %47 = llvm.bitcast %46 : vector<1xf32> to i32
        %48 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@$2 st.global.b32 [ $1 + 0 ], { $0 };", "r,l,b" %47, %44, %24 : (i32, !llvm.ptr<1>, i1) -> !llvm.void
        %49 = llvm.insertelement %43, %0[%1 : i32] : vector<1xf32>
        %50 = llvm.bitcast %49 : vector<1xf32> to i32
        %51 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@$2 st.global.b32 [ $1 + 0 ], { $0 };", "r,l,b" %50, %45, %25 : (i32, !llvm.ptr<1>, i1) -> !llvm.void
        llvm.return
      }
    }
  }
  func.func @main(%arg0: tensor<1024xf32>, %arg1: tensor<1024xf32>, %arg2: tensor<1024xf32>) attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
    %c = stablehlo.constant dense<0> : tensor<i64>
    %c_0 = stablehlo.constant dense<1> : tensor<i64>
    %c_1 = stablehlo.constant dense<16> : tensor<i64>
    triton_ext.call @add_kernel_tt_module_e72661bb113efd0f::@add_kernel_module_e72661bb113efd0f::@add_kernel_call_e72661bb113efd0f blocks in(%c_1, %c_0, %c_0) shmem = %c (%arg0, %arg1, %arg2) : (tensor<1024xf32>, tensor<1024xf32>, tensor<1024xf32>) -> ()
    return
  }
}
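
A quick consistency check on the lowered module: the kernel was compiled with num_warps=1, and at 32 threads per warp (the "ttg.threads-per-warp" = 32 attribute) that gives the nvvm.reqntid = array<i32: 32> requirement seen above.

julia> num_warps, threads_per_warp = 1, 32;

julia> num_warps * threads_per_warp  # matches nvvm.reqntid
32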

avik-pal force-pushed the ap/triton_integration branch from 3315b07 to 4a9a1ce on October 1, 2025 21:33
avik-pal force-pushed the ap/triton_integration branch 2 times, most recently from 38bbe42 to 1042dfb on October 16, 2025 13:52
avik-pal force-pushed the ap/triton_integration branch 3 times, most recently from 891a5dd to 04cbf60 on October 16, 2025 20:54
avik-pal changed the base branch from main to ap/device_props_julia on October 16, 2025 20:54
avik-pal force-pushed the ap/device_props_julia branch from 88b93d3 to 3b5b3fa on October 17, 2025 12:59
avik-pal force-pushed the ap/triton_integration branch from 5b39be4 to 1381077 on October 17, 2025 13:00
Comment on lines +47 to +48
BLOCK_SIZE,
num_stages=3;
Contributor

[JuliaFormatter] reported by reviewdog 🐶

Suggested change:

-    BLOCK_SIZE,
-    num_stages=3;
+    BLOCK_SIZE;
+    num_stages=3,
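
For context, the formatter's rule is that the semicolon in a call separates positional arguments from keyword arguments, so it belongs after the last positional argument (BLOCK_SIZE here). A hypothetical call f illustrating both forms:

f(x, BLOCK_SIZE; num_stages=3)  # preferred: positional args end at ';', keywords follow
f(x, BLOCK_SIZE, num_stages=3)  # still valid Julia, but the style used here rewrites it to the form above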

avik-pal force-pushed the ap/device_props_julia branch from 3b5b3fa to fd9f3c9 on October 18, 2025 00:23
avik-pal force-pushed the ap/triton_integration branch from 3ed4a70 to 9dde89f on October 18, 2025 00:29
avik-pal force-pushed the ap/device_props_julia branch 2 times, most recently from 190b57e to d5372aa on October 19, 2025 15:00
Base automatically changed from ap/device_props_julia to main on October 21, 2025 18:57