Commit a250418

Fix 2:4 sparsify meta registrations (#2366)
* fix 2:4 meta registrations

  Summary: We need to register in Python for symbolic shape support, which is needed for vLLM.

* add meta for sparse gemm
1 parent: 82bc17e
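
The change in one sentence: the meta ("fake") kernels for the 2:4 sparsify and sparse GEMM ops move from templated C++ registrations into Python functions in torchao/ops.py. Under torch.compile / export with dynamic shapes (which the vLLM integration relies on), tensor sizes are SymInts; a Python fake kernel that builds its output with torch.empty(...) from input.shape propagates those symbolic sizes, which the precompiled C++ meta path could not. Below is a minimal sketch of the general pattern, not torchao code: it assumes PyTorch >= 2.4's torch.library.custom_op API, "mylib::demo_gemm" is a hypothetical op, and torchao registers its metas through its own register_custom_op helper.

import torch

# Hypothetical op, for illustration only: a custom GEMM with a Python fake kernel.
@torch.library.custom_op("mylib::demo_gemm", mutates_args=())
def demo_gemm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return a @ b  # the "real" implementation

@demo_gemm.register_fake
def _(a, b):
    # Runs during tracing: a.shape[0] / b.shape[1] may be SymInts, and
    # torch.empty keeps them symbolic, so dynamic shapes flow through the graph.
    return torch.empty((a.shape[0], b.shape[1]), dtype=a.dtype, device=a.device)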

File tree: 4 files changed, +54 −43 lines


test/sparsity/test_activation24.py

Lines changed: 1 addition & 1 deletion
@@ -171,7 +171,7 @@ def test_sparse24_fp8_sm90_cutlass_gemm_eye(
     # Check MM with scale
     b_scale = torch.randn([1, A.shape[1]], device=eye.device, dtype=torch.float32)
     a_scale = torch.randn([A.shape[0], 1], device=eye.device, dtype=torch.float32)
-    A_reconstructed = torch.ops.torchao._sparse24_fp8_sm90_cutlass_gemm(
+    A_reconstructed = torch.ops.torchao.sparse24_fp8_sm90_cutlass_gemm(
         A_packed, A_mdata, eye, a_scale=a_scale, b_scale=b_scale
     )
     assert torch.allclose(

torchao/csrc/cuda/activation24/sparse_gemm.cu

Lines changed: 7 additions & 20 deletions
@@ -132,9 +132,6 @@ struct SparseRowwiseKernel<cutlass::float_e4m3_t> {
 
 template <>
 struct SparseRowwiseKernel<cutlass::bfloat16_t> {
-  static constexpr auto kElementOutAt = at::ScalarType::BFloat16;
-  static constexpr auto kElementAAt = at::ScalarType::BFloat16;
-
   using ElementA = cutlass::bfloat16_t;
   using ElementB = cutlass::bfloat16_t;
   using ElementOut = cutlass::bfloat16_t;
@@ -209,7 +206,6 @@ struct SparseRowwiseKernel<cutlass::bfloat16_t> {
   using ElementE = CollectiveMainloop::ElementE;
 };
 
-template <bool kIsMeta>
 Tensor _sparse24_fp8_sm90_cutlass_gemm(
     const Tensor& tensor_a,
     const Tensor& tensor_e, // metadata for `A`
@@ -221,20 +217,16 @@ Tensor _sparse24_fp8_sm90_cutlass_gemm(
     std::string swizzle_axis,
     int64_t sm_count) {
   std::optional<at::cuda::CUDAGuard> device_guard;
-  if (!kIsMeta) {
-    device_guard.emplace(tensor_a.device());
-  }
+  device_guard.emplace(tensor_a.device());
 
   using K = SparseRowwiseKernel<cutlass::float_e4m3_t>;
 
   // For now, only CC 9.x devices are supported.
-  if (!kIsMeta) {
-    const auto dprops = at::cuda::getCurrentDeviceProperties();
-    TORCH_CHECK(
-        dprops && dprops->major == 9,
-        "_sparse24_gemm_fp8_sm90: Supported only on GPUs with "
-        "compute capability 9.x");
-  }
+  const auto dprops = at::cuda::getCurrentDeviceProperties();
+  TORCH_CHECK(
+      dprops && dprops->major == 9,
+      "_sparse24_gemm_fp8_sm90: Supported only on GPUs with "
+      "compute capability 9.x");
 
   // Validate layouts of input tensors.
   TORCH_CHECK(tensor_a.device() == tensor_b.device());
@@ -340,12 +332,7 @@ Tensor _sparse24_fp8_sm90_cutlass_gemm(
 TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
   m.impl(
       TORCH_SELECTIVE_NAME("torchao::sparse24_fp8_sm90_cutlass_gemm"),
-      TORCH_FN(_sparse24_fp8_sm90_cutlass_gemm<false>));
+      TORCH_FN(_sparse24_fp8_sm90_cutlass_gemm));
 }
 
-TORCH_LIBRARY_IMPL(torchao, Meta, m) {
-  m.impl(
-      TORCH_SELECTIVE_NAME("torchao::sparse24_fp8_sm90_cutlass_gemm"),
-      TORCH_FN(_sparse24_fp8_sm90_cutlass_gemm<true>));
-}
 #endif
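
With the kIsMeta template gone, this file now registers only the real CUDA kernel; the Meta registration that used to live here is replaced by the Python meta function added to torchao/ops.py (last file below). A hedged sketch of the effect from the caller's side — it assumes a torchao build with the CUTLASS kernels and that the new Python registration handles meta-device dispatch, and the packed/metadata shapes are placeholders rather than the real 2:4 layout:

import torch

M, K, N = 128, 256, 512
# Placeholder shapes for the packed operand and its sparsity metadata.
a_packed = torch.empty(M, K // 2, dtype=torch.float8_e4m3fn, device="meta")
a_meta = torch.empty(M, K // 8, dtype=torch.uint8, device="meta")
b = torch.empty(K, N, dtype=torch.float8_e4m3fn, device="meta")

# No CUDA code runs here: the Python meta function supplies the output shape.
out = torch.ops.torchao.sparse24_fp8_sm90_cutlass_gemm(
    a_packed, a_meta, b, a_scale=None, b_scale=None
)
print(out.shape, out.dtype)  # torch.Size([128, 512]) torch.bfloat16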

torchao/csrc/cuda/activation24/sparsify24.cu

Lines changed: 3 additions & 17 deletions
@@ -263,7 +263,6 @@ struct SparsifyKernelParams {
 };
 
 template <
-    bool kIsMeta,
     typename MetadataFormat,
     typename ElementIn,
     typename ElementOut,
@@ -274,10 +273,8 @@ std::tuple<at::Tensor, at::Tensor> sparse24_sm90_sparsify_specialized(
     std::string sp_selection_algo,
     std::optional<at::Tensor> scale) {
   std::optional<at::cuda::CUDAGuard> device_guard;
-  if (!kIsMeta) {
-    TORCH_CHECK(input.is_cuda(), "All tensors must be on GPU");
-    device_guard.emplace(input.device());
-  }
+  TORCH_CHECK(input.is_cuda(), "All tensors must be on GPU");
+  device_guard.emplace(input.device());
 
   TORCH_CHECK(input.dim() == 2, "Can only sparsify 2d tensors");
   TORCH_CHECK(
@@ -306,9 +303,6 @@ std::tuple<at::Tensor, at::Tensor> sparse24_sm90_sparsify_specialized(
   auto launchKernel = [&](auto algo, std::string const& algo_name) {
     if (algo_name == sp_selection_algo) {
       kernel_launched = true;
-      if (kIsMeta) {
-        return;
-      }
       using Params = SparsifyKernelParams<
           ElementIn,
           ElementOut,
@@ -347,7 +341,6 @@ struct SquaredReLU {
   }
 };
 
-template <bool kIsMeta = false>
 std::tuple<at::Tensor, at::Tensor> sparse24_sm90_sparsify(
     at::Tensor input,
     std::string metadata_fmt,
@@ -363,7 +356,6 @@ std::tuple<at::Tensor, at::Tensor> sparse24_sm90_sparsify(
     using ElementIn = decltype(in_type);
     using ElementOut = decltype(out_type);
     return sparse24_sm90_sparsify_specialized<
-        kIsMeta,
         decltype(mdatafmt),
         ElementIn,
         ElementOut>(input, act, sp_selection_algo, scale);
@@ -409,11 +401,5 @@ std::tuple<at::Tensor, at::Tensor> sparse24_sm90_sparsify(
 TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
   m.impl(
       TORCH_SELECTIVE_NAME("torchao::sparse24_sm90_sparsify"),
-      TORCH_FN(sparse24_sm90_sparsify<false>));
-}
-
-TORCH_LIBRARY_IMPL(torchao, Meta, m) {
-  m.impl(
-      TORCH_SELECTIVE_NAME("torchao::sparse24_sm90_sparsify"),
-      TORCH_FN(sparse24_sm90_sparsify<true>));
+      TORCH_FN(sparse24_sm90_sparsify));
 }
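
Same story for the sparsify kernel: the CUDA code no longer carries a meta branch, and shape inference comes from the Python registration below. A hedged sketch of the resulting shape propagation — for an (M, K) input the meta function allocates (M, K // 2) packed values and (M, K // 8) uint8 metadata. The metadata_format/activation/algorithm strings are placeholders, not necessarily valid choices, and the call assumes the new registration handles meta-device dispatch:

import torch

x = torch.empty(128, 256, dtype=torch.bfloat16, device="meta")
packed, sp_meta = torch.ops.torchao.sparse24_sm90_sparsify(
    x, "cutlass", "identity", "largest"
)
print(packed.shape)   # torch.Size([128, 128]): shape[1] // 2, same dtype as x when no dtype is requested
print(sp_meta.shape)  # torch.Size([128, 32]):  shape[1] // 8, uint8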

torchao/ops.py

Lines changed: 43 additions & 5 deletions
@@ -843,15 +843,39 @@ def sparse24_sm90_sparsify(
     )
 
 
+@register_custom_op("torchao::sparse24_sm90_sparsify")
+def _(
+    input_tensor: Tensor,
+    metadata_format: str,
+    activation: str,
+    algorithm: str,
+    dtype=None,
+    scale=None,
+):
+    out_dtype = dtype if dtype is not None else input_tensor.dtype
+    return (
+        torch.empty(
+            (input_tensor.shape[0], input_tensor.shape[1] // 2),
+            dtype=out_dtype,
+            device=input_tensor.device,
+        ),
+        torch.empty(
+            (input_tensor.shape[0], input_tensor.shape[1] // 8),
+            dtype=torch.uint8,
+            device=input_tensor.device,
+        ),
+    )
+
+
 def sparse24_fp8_sm90_cutlass_gemm(
     a: Tensor,
     meta: Tensor,
     b: Tensor,
-    a_scale: Optional[Tensor],
-    b_scale: Optional[Tensor],
-    swizzle_size: int,
-    swizzle_axis: str,
-    sm_count: int,
+    a_scale: Optional[Tensor] = None,
+    b_scale: Optional[Tensor] = None,
+    swizzle_size: int = 8,
+    swizzle_axis: str = "n",
+    sm_count: int = 128,
 ) -> Tensor:
     return torch.ops.torchao.sparse24_fp8_sm90_cutlass_gemm(
         a,
@@ -865,6 +889,20 @@ def sparse24_fp8_sm90_cutlass_gemm(
     )
 
 
+@register_custom_op("torchao::sparse24_fp8_sm90_cutlass_gemm")
+def _(
+    a: Tensor,
+    meta: Tensor,
+    b: Tensor,
+    a_scale: Optional[Tensor] = None,
+    b_scale: Optional[Tensor] = None,
+    swizzle_size: int = 8,
+    swizzle_axis: str = "n",
+    sm_count: int = 128,
+):
+    return torch.empty((a.shape[0], b.shape[1]), dtype=torch.bfloat16, device=a.device)
+
+
 def swizzle_mm(
     mat1: Tensor, mat2: Tensor, mat1_is_swizzled: bool, mat2_is_swizzled: bool
 ) -> Tensor:
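
These Python meta functions are what makes the ops traceable with symbolic shapes. A hedged end-to-end sketch of that motivation, not torchao test code: it requires an SM90 GPU and a torchao build with the CUTLASS kernels, the argument strings are placeholders, and torch._dynamo.mark_dynamic is used only to force a symbolic batch dimension.

import torch

def f(x):
    packed, sp_meta = torch.ops.torchao.sparse24_sm90_sparsify(
        x, "cutlass", "identity", "largest"
    )
    return packed

x = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
torch._dynamo.mark_dynamic(x, 0)          # treat the row count as symbolic
compiled = torch.compile(f, fullgraph=True)
out = compiled(x)                          # shape inference runs through the Python meta fn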
