Skip to content

Commit a20d826

Browse files
authored
Refactor Xetla Gemm code to reduce compile time (#389)
1 parent e0d1edc commit a20d826

File tree

10 files changed

+551
-252
lines changed

10 files changed

+551
-252
lines changed

xla/service/gpu/onednn_matmul_utils.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ RunXetlaGemm(se::gpu::GpuStreamHandle handle, const MatrixDescriptor& lhs,
264264
policy
265265
.add_epilogue(
266266
c_data,
267-
::gpu::xetla::XetlaGemmKernel<InputT>::EpilogueType::RES_ADD)
267+
::gpu::xetla::EpilogueType::RES_ADD)
268268
.build();
269269
} else {
270270
return true;
@@ -283,13 +283,13 @@ RunXetlaGemm(se::gpu::GpuStreamHandle handle, const MatrixDescriptor& lhs,
283283
.add_matrix_b(rhs)
284284
.add_epilogue(
285285
bias_data,
286-
::gpu::xetla::XetlaGemmKernel<InputT>::EpilogueType::BIAS)
286+
::gpu::xetla::EpilogueType::BIAS)
287287
.build();
288288
if (fabs(beta) - 0.0f > 1e-6) {
289289
policy
290290
.add_epilogue(
291291
c_data,
292-
::gpu::xetla::XetlaGemmKernel<InputT>::EpilogueType::RES_ADD,
292+
::gpu::xetla::EpilogueType::RES_ADD,
293293
beta)
294294
.build();
295295
}
@@ -306,7 +306,7 @@ RunXetlaGemm(se::gpu::GpuStreamHandle handle, const MatrixDescriptor& lhs,
306306
.add_matrix_b(rhs)
307307
.add_epilogue(
308308
nullptr,
309-
::gpu::xetla::XetlaGemmKernel<InputT>::EpilogueType::GELU)
309+
::gpu::xetla::EpilogueType::GELU)
310310
.build();
311311
if (policy.fallback() == false) {
312312
return !policy.run(handle);
@@ -321,10 +321,10 @@ RunXetlaGemm(se::gpu::GpuStreamHandle handle, const MatrixDescriptor& lhs,
321321
.add_matrix_b(rhs)
322322
.add_epilogue(
323323
bias_data,
324-
::gpu::xetla::XetlaGemmKernel<InputT>::EpilogueType::BIAS)
324+
::gpu::xetla::EpilogueType::BIAS)
325325
.add_epilogue(
326326
nullptr,
327-
::gpu::xetla::XetlaGemmKernel<InputT>::EpilogueType::GELU)
327+
::gpu::xetla::EpilogueType::GELU)
328328
.build();
329329
if (policy.fallback() == false) {
330330
return !policy.run(handle);

xla/service/gpu/xetla/gemm/BUILD

Lines changed: 77 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,96 @@
11
load("//xla:xla.bzl", "xetla_library")
22

3-
# List all kernels here.
3+
# Header-only target exposing gemm_common.h to the other GEMM targets in this
# package.
xetla_library(
    name = "gemm_common",
    hdrs = ["gemm_common.h"],
    copts = [
        "-Wall",
        "-Wno-c++11-narrowing",
    ],
    visibility = ["//visibility:public"],
    deps = ["//xla/service/gpu:matrix_descriptor"],
)
17+
18+
# Header-only target bundling the dispatch header with the hgemm and epilogue
# implementation headers. The layout-specific dispatch targets depend on it.
xetla_library(
    name = "gemm_dispatch",
    hdrs = [
        "gemm_dispatch.h",
        "hgemm_impl.h",
        "epilogue_impl.h",
    ],
    copts = [
        "-Wall",
        "-Wno-c++11-narrowing",
    ],
    visibility = ["//visibility:public"],
    # Dependency order is kept exactly as originally listed.
    deps = [
        ":gemm_common",
        "//xla/service/gpu:matrix_descriptor",
        "//xla/stream_executor/sycl:sycl_executor",
        "@xetla//:xetla_header",
        "@com_google_absl//absl/strings",
    ],
)
38+
39+
# Row-major dispatcher, built as its own target from dispatch_row_major.cc.
xetla_library(
    name = "dispatch_row_major",
    srcs = ["dispatch_row_major.cc"],
    hdrs = ["dispatch_row_major.h"],
    copts = [
        "-Wall",
        "-Wno-c++11-narrowing",
    ],
    visibility = ["//visibility:public"],
    deps = [
        ":gemm_dispatch",
        "//xla/stream_executor/sycl:sycl_executor",
    ],
)
57+
58+
# Column-major dispatcher, built as its own target from dispatch_col_major.cc.
xetla_library(
    name = "dispatch_col_major",
    srcs = ["dispatch_col_major.cc"],
    hdrs = ["dispatch_col_major.h"],
    copts = [
        "-Wall",
        "-Wno-c++11-narrowing",
    ],
    visibility = ["//visibility:public"],
    deps = [
        ":gemm_dispatch",
        "//xla/stream_executor/sycl:sycl_executor",
    ],
)
76+
477
xetla_library(
578
name = "gemm_kernel",
679
srcs = [
780
"gemm.cc",
881
],
982
hdrs = [
1083
"gemm.h",
11-
"hgemm_impl.h",
12-
"epilogue_impl.h",
1384
],
1485
copts = [
1586
"-Wall",
1687
"-Wno-c++11-narrowing",
1788
],
1889
visibility = ["//visibility:public"],
1990
deps = [
91+
":gemm_common",
92+
":dispatch_row_major",
93+
":dispatch_col_major",
2094
"//xla/service/gpu:matrix_descriptor",
2195
"//xla/stream_executor/sycl:sycl_executor",
2296
"@xetla//:xetla_header",
xla/service/gpu/xetla/gemm/dispatch_col_major.cc

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/* Copyright (c) 2024 Intel Corporation
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "xla/service/gpu/xetla/gemm/dispatch_col_major.h"
17+
18+
#include "xla/service/gpu/xetla/gemm/gemm_common.h"
19+
#include "xla/service/gpu/xetla/gemm/gemm_dispatch.h"
20+
#include "xla/stream_executor/gpu/gpu_types.h"
21+
22+
namespace gpu {
23+
namespace xetla {
24+
25+
template <typename ComputeType>
26+
bool GemmColMajorDispatcher<ComputeType>::run(se::gpu::GpuStreamHandle handle) {
27+
int WG_M = std::get<0>(selected_policy_id_);
28+
int WG_N = std::get<1>(selected_policy_id_);
29+
int SG_M = std::get<2>(selected_policy_id_);
30+
int SG_N = std::get<3>(selected_policy_id_);
31+
int SG_K = std::get<4>(selected_policy_id_);
32+
int SLM_KS = std::get<5>(selected_policy_id_);
33+
return gemm_policy<ComputeType>::call(WG_M, WG_N, SG_M, SG_N, SG_K, SLM_KS,
34+
this, handle);
35+
}
36+
37+
template <typename ComputeType>
38+
template <int WG_M, int WG_N, int SG_M, int SG_N, int SG_K, int SLM_KS>
39+
bool GemmColMajorDispatcher<ComputeType>::dispatch(
40+
se::gpu::GpuStreamHandle handle) {
41+
return do_dispatch<ComputeType, WG_M, WG_N, SG_M, SG_N, SG_K, SLM_KS, false>(
42+
handle, params_);
43+
}
44+
45+
template class GemmColMajorDispatcher<sycl::half>;
46+
template class GemmColMajorDispatcher<gpu::xetla::bf16>;
47+
48+
} // namespace xetla
49+
} // namespace gpu
xla/service/gpu/xetla/gemm/dispatch_col_major.h

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/* Copyright (c) 2024 Intel Corporation
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#ifndef XLA_SERVICE_GPU_XETLA_GEMM_DISPATCH_COL_MAJOR_H_
17+
#define XLA_SERVICE_GPU_XETLA_GEMM_DISPATCH_COL_MAJOR_H_
18+
19+
#include "xla/service/gpu/xetla/gemm/gemm_common.h"
20+
#include "xla/service/gpu/xetla/gemm/gemm_dispatch.h"
21+
#include "xla/stream_executor/gpu/gpu_types.h"
22+
23+
namespace gpu {
24+
namespace xetla {
25+
26+
template <typename ComputeType>
27+
class GemmColMajorDispatcher {
28+
public:
29+
GemmColMajorDispatcher() = default;
30+
31+
GemmColMajorDispatcher(
32+
DispatchParams* params,
33+
std::tuple<int, int, int, int, int, int> selected_policy_id)
34+
: params_(params), selected_policy_id_(selected_policy_id) {}
35+
36+
template <int WG_M, int WG_N, int SG_M, int SG_N, int SG_K, int SLM_KS>
37+
bool dispatch(se::gpu::GpuStreamHandle handle);
38+
39+
bool run(se::gpu::GpuStreamHandle handle);
40+
41+
private:
42+
DispatchParams* params_;
43+
std::tuple<int, int, int, int, int, int> selected_policy_id_;
44+
};
45+
46+
} // namespace xetla
47+
} // namespace gpu
48+
49+
#endif // XLA_SERVICE_GPU_XETLA_GEMM_DISPATCH_COL_MAJOR_H_
xla/service/gpu/xetla/gemm/dispatch_row_major.cc

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/* Copyright (c) 2024 Intel Corporation
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "xla/service/gpu/xetla/gemm/dispatch_row_major.h"
17+
18+
#include "xla/service/gpu/xetla/gemm/gemm_common.h"
19+
#include "xla/service/gpu/xetla/gemm/gemm_dispatch.h"
20+
#include "xla/stream_executor/gpu/gpu_types.h"
21+
22+
namespace gpu {
23+
namespace xetla {
24+
25+
template <typename ComputeType>
26+
bool GemmRowMajorDispatcher<ComputeType>::run(se::gpu::GpuStreamHandle handle) {
27+
int WG_M = std::get<0>(selected_policy_id_);
28+
int WG_N = std::get<1>(selected_policy_id_);
29+
int SG_M = std::get<2>(selected_policy_id_);
30+
int SG_N = std::get<3>(selected_policy_id_);
31+
int SG_K = std::get<4>(selected_policy_id_);
32+
int SLM_KS = std::get<5>(selected_policy_id_);
33+
return gemm_policy<ComputeType>::call(WG_M, WG_N, SG_M, SG_N, SG_K, SLM_KS,
34+
this, handle);
35+
}
36+
37+
template <typename ComputeType>
38+
template <int WG_M, int WG_N, int SG_M, int SG_N, int SG_K, int SLM_KS>
39+
bool GemmRowMajorDispatcher<ComputeType>::dispatch(
40+
se::gpu::GpuStreamHandle handle) {
41+
return do_dispatch<ComputeType, WG_M, WG_N, SG_M, SG_N, SG_K, SLM_KS, true>(
42+
handle, params_);
43+
}
44+
45+
template class GemmRowMajorDispatcher<sycl::half>;
46+
template class GemmRowMajorDispatcher<gpu::xetla::bf16>;
47+
48+
} // namespace xetla
49+
} // namespace gpu
xla/service/gpu/xetla/gemm/dispatch_row_major.h

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/* Copyright (c) 2024 Intel Corporation
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#ifndef XLA_SERVICE_GPU_XETLA_GEMM_DISPATCH_ROW_MAJOR_H_
17+
#define XLA_SERVICE_GPU_XETLA_GEMM_DISPATCH_ROW_MAJOR_H_
18+
19+
#include "xla/service/gpu/xetla/gemm/gemm_common.h"
20+
#include "xla/service/gpu/xetla/gemm/gemm_dispatch.h"
21+
#include "xla/stream_executor/gpu/gpu_types.h"
22+
23+
namespace gpu {
24+
namespace xetla {
25+
26+
template <typename ComputeType>
27+
class GemmRowMajorDispatcher {
28+
public:
29+
GemmRowMajorDispatcher() = default;
30+
31+
GemmRowMajorDispatcher(
32+
DispatchParams* params,
33+
std::tuple<int, int, int, int, int, int> selected_policy_id)
34+
: params_(params), selected_policy_id_(selected_policy_id) {}
35+
36+
template <int WG_M, int WG_N, int SG_M, int SG_N, int SG_K, int SLM_KS>
37+
bool dispatch(se::gpu::GpuStreamHandle handle);
38+
39+
bool run(se::gpu::GpuStreamHandle handle);
40+
41+
private:
42+
DispatchParams* params_;
43+
std::tuple<int, int, int, int, int, int> selected_policy_id_;
44+
};
45+
46+
} // namespace xetla
47+
} // namespace gpu
48+
49+
#endif // XLA_SERVICE_GPU_XETLA_GEMM_DISPATCH_ROW_MAJOR_H_

0 commit comments

Comments
 (0)