464 changes: 461 additions & 3 deletions llvm/docs/NVPTXUsage.rst

Large diffs are not rendered by default.

213 changes: 210 additions & 3 deletions llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -764,6 +764,76 @@ class NVVM_TCGEN05_LDST_ACCESS_SIZE<string Shape, int Num> {
true : llvm_void_ty);
}

class NVVM_TCGEN05_MMA_BASE<string Space> {
  LLVMType a_operand_type = !if(!eq(Space, "tensor"),
                                llvm_tmem_ptr_ty, llvm_i64_ty);
  list<LLVMType> common_args = [llvm_tmem_ptr_ty, // d
                                a_operand_type,   // a
                                llvm_i64_ty,      // b
                                llvm_i32_ty,      // idesc
                                llvm_i1_ty];      // enable_input_d
  list<IntrinsicProperty> common_intr_props = !listconcat(
    [IntrArgMemOnly, WriteOnly<ArgIndex<0>>],
    !if(!eq(Space, "tensor"), [ReadOnly<ArgIndex<1>>], [])
  );
}

class NVVM_TCGEN05_MMA<bit Sp, string Space,
                       bit AShift, bit ScaleInputD>:
      NVVM_TCGEN05_MMA_BASE<Space> {
  string intr = "llvm.nvvm.tcgen05.mma"
                # !if(!eq(Sp, 1), ".sp", "")
                # "." # Space
                # !if(!eq(ScaleInputD, 1), ".scale_d", "")
                # !if(!eq(AShift, 1), ".ashift", "");
  string record = !subst(".", "_", !subst("llvm.", "int_", intr));
}
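
// Illustrative expansions of the naming scheme above (comments only, not
// definitions; derived by hand from the string concatenation in this class):
//   NVVM_TCGEN05_MMA<0, "tensor", 0, 0>.intr   == "llvm.nvvm.tcgen05.mma.tensor"
//   NVVM_TCGEN05_MMA<0, "tensor", 0, 0>.record == "int_nvvm_tcgen05_mma_tensor"
//   NVVM_TCGEN05_MMA<1, "shared", 0, 1>.intr   == "llvm.nvvm.tcgen05.mma.sp.shared.scale_d"
//   NVVM_TCGEN05_MMA<1, "shared", 0, 1>.record == "int_nvvm_tcgen05_mma_sp_shared_scale_d"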

class NVVM_TCGEN05_MMA_BLOCKSCALE<bit Sp, string Space,
                                  string Kind, string ScaleVecSize>:
      NVVM_TCGEN05_MMA_BASE<Space> {
  string intr = "llvm.nvvm.tcgen05.mma"
                # !if(!eq(Sp, 1), ".sp", "")
                # "." # Space
                # "." # Kind # ScaleVecSize
                # ".block_scale";
  string record = !subst(".", "_", !subst("llvm.", "int_", intr));
}

class NVVM_TCGEN05_MMA_WS<bit Sp, string Space, bit ZeroColMask>:
      NVVM_TCGEN05_MMA_BASE<Space> {
  string intr = "llvm.nvvm.tcgen05.mma.ws"
                # !if(!eq(Sp, 1), ".sp", "")
                # "." # Space
                # !if(!eq(ZeroColMask, 1), ".zero_col_mask", "");
  string record = !subst(".", "_", !subst("llvm.", "int_", intr));
}

class NVVM_TCGEN05_MMA_DISABLE_OUTPUT_LANE<bit Sp, string Space,
                                           int CtaGroup, bit AShift,
                                           bit ScaleInputD>:
      NVVM_TCGEN05_MMA_BASE<Space> {
  string intr = "llvm.nvvm.tcgen05.mma"
                # !if(!eq(Sp, 1), ".sp", "")
                # "." # Space
                # !if(!eq(ScaleInputD, 1), ".scale_d", "")
                # ".disable_output_lane.cg" # CtaGroup
                # !if(!eq(AShift, 1), ".ashift", "");
  string record = !subst(".", "_", !subst("llvm.", "int_", intr));
}

class NVVM_TCGEN05_MMA_BLOCKSCALE_SUPPORTED<string Kind, string ScaleVecSize> {
  bit ret = !cond(
    !and(!eq(Kind, "mxf8f6f4"), !eq(ScaleVecSize, ""))         : true,
    !and(!eq(Kind, "mxf4"),     !eq(ScaleVecSize, ""))         : true,
    !and(!eq(Kind, "mxf4nvf4"), !eq(ScaleVecSize, ".block16")) : true,
    !and(!eq(Kind, "mxf4"),     !eq(ScaleVecSize, ".block32")) : true,
    !and(!eq(Kind, "mxf4nvf4"), !eq(ScaleVecSize, ".block32")) : true,
    !and(!eq(Kind, "mxf8f6f4"), !eq(ScaleVecSize, ".block32")) : true,
    true: false
  );
}
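
// Quick reference for the !cond table above (illustrative comment only): the
// default scale-vector size ("") is accepted for mxf8f6f4 and mxf4, .block16
// only for mxf4nvf4, and .block32 for all three kinds; every other
// (kind, scale_vec_size) pair yields ret = false and is skipped below.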

class TexVector<string name, list<LLVMType> types> {
  string Name = name;
  list<LLVMType> Types = types;
@@ -2070,13 +2140,15 @@ def int_nvvm_exit : NVVMBuiltin,
class DefaultAttrsIntrinsicFlags<list<LLVMType> ret_types,
                                 list<LLVMType> param_types,
                                 list<LLVMType> flags,
                                 list<IntrinsicProperty> intr_properties>
                                 list<IntrinsicProperty> intr_properties,
                                 string name = "">
  : DefaultAttrsIntrinsic<
      ret_types,
      !listconcat(param_types, flags),
      !listconcat(intr_properties,
        !foreach(i, !range(flags),
          ImmArg<ArgIndex<!add(i, !size(param_types))>>))>;
          ImmArg<ArgIndex<!add(i, !size(param_types))>>)),
      name>;
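
// Note (illustrative comment): the new `name` parameter is forwarded to
// DefaultAttrsIntrinsic so a record can carry an explicit intrinsic name. The
// tcgen05.mma definitions below rely on this and pass `mma.intr`, presumably
// because names containing underscores such as "disable_output_lane" or
// "scale_d" cannot be recovered from the record name by the default
// dot-for-underscore substitution.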

// TMA Tensor Copy Intrinsics: S2G -> From Shared to Global memory variants
foreach dim = 1...5 in {
@@ -2464,4 +2536,139 @@ def int_nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_ # dim
"llvm.nvvm.clusterlaunchcontrol.query_cancel.get_first_ctaid." # dim>;
}

} // let TargetPrefix = "nvvm"
//
// tcgen05.mma intrinsics
//

foreach sp = [0, 1] in {
  foreach space = ["tensor", "shared"] in {
    foreach scale_d = [0, 1] in {
      foreach ashift = !if(!eq(space, "tensor"), [0, 1], [0]) in {
        defvar mma = NVVM_TCGEN05_MMA<sp, space, ashift, scale_d>;
        defvar args = !listconcat(
          mma.common_args,
          !if(!eq(sp, 1), [llvm_tmem_ptr_ty], []),
          !if(!eq(scale_d, 1), [llvm_i64_ty], [])
        );
        defvar flags = [llvm_i32_ty,  // kind
                        llvm_i32_ty,  // cta_group
                        llvm_i32_ty]; // collector_usage_a
        defvar nargs = !size(args);
        defvar scale_d_imm = ArgIndex<!sub(nargs, 1)>;
        defvar intrinsic_properties = !listconcat(
          mma.common_intr_props,
          !if(!eq(scale_d, 1),
              [ImmArg<scale_d_imm>, Range<scale_d_imm, 0, 16>], []),
          [Range<ArgIndex<nargs>, 0, !if(!eq(scale_d, 1), 2, 4)>,
           Range<ArgIndex<!add(nargs, 1)>, 1, 3>,
           Range<ArgIndex<!add(nargs, 2)>, 0, !if(!eq(ashift, 1), 2, 4)>]
        );

        def mma.record:
          DefaultAttrsIntrinsicFlags<[], args, flags, intrinsic_properties,
                                     mma.intr>;
      }
    }
  }
}
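
// For illustration only: assuming llvm_tmem_ptr_ty lowers to `ptr addrspace(6)`
// (the tensor-memory address space), the simplest instantiation above
// (sp = 0, space = "tensor", scale_d = 0, ashift = 0) corresponds to an IR
// declaration along the lines of:
//
//   declare void @llvm.nvvm.tcgen05.mma.tensor(
//       ptr addrspace(6) %d, ptr addrspace(6) %a, i64 %b, i32 %idesc,
//       i1 %enable_input_d, i32 %kind, i32 %cta_group, i32 %collector_usage_a)
//
// with %kind constrained to [0, 4), %cta_group to [1, 3), and
// %collector_usage_a to [0, 4) by the Range properties above.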

//
// tcgen05.mma disable_output_lane intrinsics
//
foreach sp = [0, 1] in {
  foreach space = ["tensor", "shared"] in {
    foreach cta_group = [1, 2] in {
      foreach scale_d = [0, 1] in {
        foreach ashift = !if(!eq(space, "tensor"), [0, 1], [0]) in {
          defvar mma = NVVM_TCGEN05_MMA_DISABLE_OUTPUT_LANE<
                           sp, space, cta_group, ashift, scale_d>;
          defvar disable_output_lane_type =
              !if(!eq(cta_group, 1), llvm_v4i32_ty, llvm_v8i32_ty);
          defvar args = !listconcat(
            mma.common_args,
            !if(!eq(sp, 1), [llvm_tmem_ptr_ty], []),
            !if(!eq(scale_d, 1), [llvm_i64_ty], []),
            [disable_output_lane_type]
          );
          defvar flags = [llvm_i32_ty,  // kind_flag
                          llvm_i32_ty]; // collector_usage_a_flag
          defvar nargs = !size(args);
          defvar scale_d_flag = ArgIndex<!sub(nargs, 2)>;
          defvar scale_d_imm_range = [ImmArg<scale_d_flag>, Range<scale_d_flag, 0, 16>];
          defvar intrinsic_properties = !listconcat(
            mma.common_intr_props,
            !if(!eq(scale_d, 1), scale_d_imm_range, []),
            [Range<ArgIndex<nargs>, 0, !if(!eq(scale_d, 1), 2, 4)>,
             Range<ArgIndex<!add(nargs, 1)>, 0, !if(!eq(ashift, 1), 2, 4)>]
          );

          def mma.record: DefaultAttrsIntrinsicFlags<[], args, flags, intrinsic_properties,
                                                     mma.intr>;
        } // ashift
      } // scale_d
    } // cta_group
  } // space
} // sp
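
// Illustrative naming only: the cta_group::1 form appends a v4i32 mask and the
// cta_group::2 form a v8i32 mask as the last regular operand, producing records
// such as int_nvvm_tcgen05_mma_tensor_disable_output_lane_cg1 and
// int_nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2_ashift
// (derived by hand from NVVM_TCGEN05_MMA_DISABLE_OUTPUT_LANE above).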

//
// tcgen05.mma block_scale intrinsics
//
foreach sp = [0, 1] in {
  foreach space = ["tensor", "shared"] in {
    foreach kind = ["mxf8f6f4", "mxf4", "mxf4nvf4"] in {
      foreach scale_vec_size = ["", ".block16", ".block32"] in {
        defvar mma = NVVM_TCGEN05_MMA_BLOCKSCALE<sp, space, kind, scale_vec_size>;
        defvar cta_group = ArgIndex<!if(!eq(sp, 1), 8, 7)>;
        defvar collector_usage = ArgIndex<!if(!eq(sp, 1), 9, 8)>;

        if NVVM_TCGEN05_MMA_BLOCKSCALE_SUPPORTED<kind, scale_vec_size>.ret then {
          def mma.record: DefaultAttrsIntrinsicFlags<[],
            !listconcat(mma.common_args,
                        !if(!eq(sp, 1),
                            [llvm_tmem_ptr_ty], []), // spmetadata
                        [llvm_tmem_ptr_ty,   // scale a
                         llvm_tmem_ptr_ty]), // scale b
            // flags
            [llvm_i32_ty,  // cta_group
             llvm_i32_ty], // collector_usage_a
            !listconcat(mma.common_intr_props,
                        [Range<cta_group, 1, 3>,
                         Range<collector_usage, 0, 4>]),
            mma.intr>;
        }
      }
    }
  }
}
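
// Operand layout for reference (illustrative comment, matching the ArgIndex
// arithmetic above): for the sparse (.sp) variants the order is
//   d, a, b, idesc, enable_input_d, spmetadata, scale_a, scale_b,
//   cta_group, collector_usage_a
// so cta_group sits at index 8 and collector_usage_a at index 9; the dense
// variants drop spmetadata and shift both flag indices down by one.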

//
// tcgen05.mma ws intrinsics
//
foreach sp = [0, 1] in {
  foreach space = ["tensor", "shared"] in {
    foreach zero_col_mask = [0, 1] in {
      defvar mma = NVVM_TCGEN05_MMA_WS<sp, space, zero_col_mask>;
      defvar args = !listconcat(
        mma.common_args,
        !if(!eq(sp, 1), [llvm_tmem_ptr_ty], []),
        !if(!eq(zero_col_mask, 1), [llvm_i64_ty], [])
      );
      defvar flags = [llvm_i32_ty,  // kind
                      llvm_i32_ty,  // collector_buffer_b
                      llvm_i32_ty]; // collector_usage_b_op
      defvar nargs = !size(args);
      defvar intrinsic_properties = !listconcat(
        mma.common_intr_props,
        [Range<ArgIndex<nargs>, 0, 4>,
         Range<ArgIndex<!add(nargs, 1)>, 0, 4>,
         Range<ArgIndex<!add(nargs, 2)>, 0, 4>]
      );

      def mma.record:
        DefaultAttrsIntrinsicFlags<[], args, flags, intrinsic_properties,
                                   mma.intr>;
    }
  }
}
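
// Illustrative expansion: NVVM_TCGEN05_MMA_WS<1, "shared", 1> yields the
// intrinsic name "llvm.nvvm.tcgen05.mma.ws.sp.shared.zero_col_mask", with the
// i64 zero-column mask appended after the sparsity-metadata operand and all
// three flag immediates constrained to [0, 4).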

} // let TargetPrefix = "nvvm"
9 changes: 9 additions & 0 deletions llvm/include/llvm/IR/NVVMIntrinsicUtils.h
@@ -47,6 +47,15 @@ enum class CTAGroupKind : uint8_t {
  CG_2 = 2, // cta_group::2 modifier
};

enum class Tcgen05MMAKind : uint8_t { F16 = 0, TF32 = 1, F8F6F4 = 2, I8 = 3 };

enum class Tcgen05CollectorUsageOp : uint8_t {
  DISCARD = 0,
  LASTUSE = 1,
  FILL = 2,
  USE = 3,
};
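
// Presumed correspondence with the flag immediates defined in IntrinsicsNVVM.td
// (illustrative comment): the tcgen05.mma `kind` flag takes values 0-3 matching
// Tcgen05MMAKind (F16, TF32, F8F6F4, I8), and the collector-usage flag takes
// values 0-3 matching Tcgen05CollectorUsageOp (DISCARD, LASTUSE, FILL, USE),
// which is consistent with those flags being range-checked against [0, 4).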

inline bool FPToIntegerIntrinsicShouldFTZ(Intrinsic::ID IntrinsicID) {
  switch (IntrinsicID) {
  case Intrinsic::nvvm_f2i_rm_ftz: