Skip to content

Commit 564afa1

Browse files
jiawenliu64 authored and meta-codesync[bot] committed
Improve general shape performance for fprop cutlass grouped gemm (pytorch#4988)
Summary: Pull Request resolved: pytorch#4988 X-link: https://github.com/facebookresearch/FBGEMM/pull/2003 Improve general shape performance for fprop cutlass grouped gemm Reviewed By: q10, cthi Differential Revision: D84194717
1 parent e22df7f commit 564afa1

22 files changed

+689
-742
lines changed

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped.cu

Lines changed: 570 additions & 119 deletions
Large diffs are not rendered by default.

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_256_128_4_1_1_9_t.cu renamed to fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_2_4_1_9_t.cu

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,17 @@
1010

1111
namespace fbgemm_gpu {
1212

13-
at::Tensor bf16bf16bf16_grouped_128_256_128_4_1_1_9_t(
13+
at::Tensor bf16bf16bf16_grouped_128_128_128_2_4_1_9_t(
1414
at::Tensor X, // BF16
1515
at::Tensor W, // BF16
1616
at::Tensor output,
1717
std::optional<at::Tensor> zero_start_index_M,
1818
std::optional<at::Tensor> M_sizes) {
19-
return bf16bf16bf16_grouped_impl<at::Tensor, 128, 256, 128, 4, 1, 1, true>(
19+
return bf16bf16bf16_grouped_impl<at::Tensor, 128, 128, 128, 2, 4, 1, true>(
2020
X, W, output, zero_start_index_M, M_sizes);
2121
}
2222

23-
at::Tensor bf16bf16bf16_grouped_128_256_128_4_1_1_9_t(
23+
at::Tensor bf16bf16bf16_grouped_128_128_128_2_4_1_9_t(
2424
at::TensorList X, // BF16
2525
at::TensorList W, // BF16
2626
at::Tensor output,
@@ -29,11 +29,11 @@ at::Tensor bf16bf16bf16_grouped_128_256_128_4_1_1_9_t(
2929
return bf16bf16bf16_grouped_impl<
3030
at::TensorList,
3131
128,
32-
256,
3332
128,
33+
128,
34+
2,
3435
4,
3536
1,
36-
1,
3737
true>(X, W, output, zero_start_index_M, M_sizes);
3838
}
3939

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_4_1_1_9_f.cu

Lines changed: 0 additions & 40 deletions
This file was deleted.

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_4_1_1_9_t.cu renamed to fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_4_2_1_9_t.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,17 @@
1010

1111
namespace fbgemm_gpu {
1212

13-
at::Tensor bf16bf16bf16_grouped_128_128_128_4_1_1_9_t(
13+
at::Tensor bf16bf16bf16_grouped_128_128_128_4_2_1_9_t(
1414
at::Tensor X, // BF16
1515
at::Tensor W, // BF16
1616
at::Tensor output,
1717
std::optional<at::Tensor> zero_start_index_M,
1818
std::optional<at::Tensor> M_sizes) {
19-
return bf16bf16bf16_grouped_impl<at::Tensor, 128, 128, 128, 4, 1, 1, true>(
19+
return bf16bf16bf16_grouped_impl<at::Tensor, 128, 128, 128, 4, 2, 1, true>(
2020
X, W, output, zero_start_index_M, M_sizes);
2121
}
2222

23-
at::Tensor bf16bf16bf16_grouped_128_128_128_4_1_1_9_t(
23+
at::Tensor bf16bf16bf16_grouped_128_128_128_4_2_1_9_t(
2424
at::TensorList X, // BF16
2525
at::TensorList W, // BF16
2626
at::Tensor output,
@@ -32,7 +32,7 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_4_1_1_9_t(
3232
128,
3333
128,
3434
4,
35-
1,
35+
2,
3636
1,
3737
true>(X, W, output, zero_start_index_M, M_sizes);
3838
}

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_256_128_1_1_1_9_t.cu

Lines changed: 0 additions & 40 deletions
This file was deleted.

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_128_128_2_1_1_9_f.cu renamed to fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_256_128_1_2_1_9_f.cu

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,29 +10,29 @@
1010

1111
namespace fbgemm_gpu {
1212

13-
at::Tensor bf16bf16bf16_grouped_256_128_128_2_1_1_9_f(
13+
at::Tensor bf16bf16bf16_grouped_128_256_128_1_2_1_9_f(
1414
at::Tensor X, // BF16
1515
at::Tensor W, // BF16
1616
at::Tensor output,
1717
std::optional<at::Tensor> zero_start_index_M,
1818
std::optional<at::Tensor> M_sizes) {
19-
return bf16bf16bf16_grouped_impl<at::Tensor, 256, 128, 128, 2, 1, 1, false>(
19+
return bf16bf16bf16_grouped_impl<at::Tensor, 128, 256, 128, 1, 2, 1, false>(
2020
X, W, output, zero_start_index_M, M_sizes);
2121
}
2222

23-
at::Tensor bf16bf16bf16_grouped_256_128_128_2_1_1_9_f(
23+
at::Tensor bf16bf16bf16_grouped_128_256_128_1_2_1_9_f(
2424
at::TensorList X, // BF16
2525
at::TensorList W, // BF16
2626
at::Tensor output,
2727
std::optional<at::Tensor> zero_start_index_M,
2828
std::optional<at::Tensor> M_sizes) {
2929
return bf16bf16bf16_grouped_impl<
3030
at::TensorList,
31-
256,
3231
128,
32+
256,
3333
128,
34-
2,
3534
1,
35+
2,
3636
1,
3737
false>(X, W, output, zero_start_index_M, M_sizes);
3838
}

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_256_128_2_1_1_9_t.cu

Lines changed: 0 additions & 40 deletions
This file was deleted.

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_256_128_4_1_1_9_f.cu

Lines changed: 0 additions & 40 deletions
This file was deleted.

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_16_128_2_1_1_9_f.cu renamed to fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_32_128_1_2_1_9_f.cu

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,17 @@
1010

1111
namespace fbgemm_gpu {
1212

13-
at::Tensor bf16bf16bf16_grouped_128_16_128_2_1_1_9_f(
13+
at::Tensor bf16bf16bf16_grouped_128_32_128_1_2_1_9_f(
1414
at::Tensor X, // BF16
1515
at::Tensor W, // BF16
1616
at::Tensor output,
1717
std::optional<at::Tensor> zero_start_index_M,
1818
std::optional<at::Tensor> M_sizes) {
19-
return bf16bf16bf16_grouped_impl<at::Tensor, 128, 16, 128, 2, 1, 1, false>(
19+
return bf16bf16bf16_grouped_impl<at::Tensor, 128, 32, 128, 1, 2, 1, false>(
2020
X, W, output, zero_start_index_M, M_sizes);
2121
}
2222

23-
at::Tensor bf16bf16bf16_grouped_128_16_128_2_1_1_9_f(
23+
at::Tensor bf16bf16bf16_grouped_128_32_128_1_2_1_9_f(
2424
at::TensorList X, // BF16
2525
at::TensorList W, // BF16
2626
at::Tensor output,
@@ -29,10 +29,10 @@ at::Tensor bf16bf16bf16_grouped_128_16_128_2_1_1_9_f(
2929
return bf16bf16bf16_grouped_impl<
3030
at::TensorList,
3131
128,
32-
16,
32+
32,
3333
128,
34-
2,
3534
1,
35+
2,
3636
1,
3737
false>(X, W, output, zero_start_index_M, M_sizes);
3838
}

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_16_128_4_1_1_9_f.cu renamed to fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_32_128_1_4_1_9_f.cu

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,17 @@
1010

1111
namespace fbgemm_gpu {
1212

13-
at::Tensor bf16bf16bf16_grouped_128_16_128_4_1_1_9_f(
13+
at::Tensor bf16bf16bf16_grouped_128_32_128_1_4_1_9_f(
1414
at::Tensor X, // BF16
1515
at::Tensor W, // BF16
1616
at::Tensor output,
1717
std::optional<at::Tensor> zero_start_index_M,
1818
std::optional<at::Tensor> M_sizes) {
19-
return bf16bf16bf16_grouped_impl<at::Tensor, 128, 16, 128, 4, 1, 1, false>(
19+
return bf16bf16bf16_grouped_impl<at::Tensor, 128, 32, 128, 1, 4, 1, false>(
2020
X, W, output, zero_start_index_M, M_sizes);
2121
}
2222

23-
at::Tensor bf16bf16bf16_grouped_128_16_128_4_1_1_9_f(
23+
at::Tensor bf16bf16bf16_grouped_128_32_128_1_4_1_9_f(
2424
at::TensorList X, // BF16
2525
at::TensorList W, // BF16
2626
at::Tensor output,
@@ -29,10 +29,10 @@ at::Tensor bf16bf16bf16_grouped_128_16_128_4_1_1_9_f(
2929
return bf16bf16bf16_grouped_impl<
3030
at::TensorList,
3131
128,
32-
16,
32+
32,
3333
128,
34-
4,
3534
1,
35+
4,
3636
1,
3737
false>(X, W, output, zero_start_index_M, M_sizes);
3838
}

0 commit comments

Comments (0)