
Commit 0e7235e

Stonepia authored and EikanWang committed
[xpu][feature] [1/3] add fp8 scaled_mm implementation for XPU (pytorch#165978)
This PR implements `scaled_mm` for XPU. It enables the following data types:
1. TensorWise scaling: `fp8_e4m3` and `fp8_e5m2`
2. RowWise scaling: `fp8_e4m3` and `fp8_e5m2`

BlockWise scaling is left to the next PR to keep this review smaller.

This first PR only adds `scaled_mm_xpu` without registering it; the registration is split out to reduce review effort. Secondly, there is a `scaled_mm_v2` API in pytorch#164141. We will align with it once v1 is cleaned up.

**Co-authors:** @yuchengliu1, @carsonwang

## PR stack:
- -> pytorch#165978: implementation of XPU scaled_mm and oneDNN kernel
- pytorch#167518: implementation of XPU scaled_mm_v2
- pytorch#166056: op registration

## Test status:
1. Relies on the changes in intel/torch-xpu-ops#1746; otherwise the op will fall back to CPU.
2. This PR does not include tests; the tests are enabled in pytorch#166056.

## Credit:
This work is based on @yuchengliu1's work in pytorch#140972. We created a new PR to align the API and checks with CUDA, so that less porting effort is needed.

## FP8 task tracker:
We will track all scaled_mm-related tasks in pytorch#167170.

Pull Request resolved: pytorch#165978
Approved by: https://github.com/liangan1, https://github.com/EikanWang

Co-authored-by: Eikan Wang <[email protected]>
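For reference, below is a minimal sketch of how this op is expected to be called from Python once it is registered for XPU (registration lands in pytorch#166056, so this will not dispatch to the XPU kernel yet). The `torch._scaled_mm` call mirrors the `_scaled_mm_xpu` signature added in this PR; the `"xpu"` device string and the shapes are illustrative assumptions, not part of this change.

```python
import torch

# Illustrative sketch only: assumes an XPU build of PyTorch with this op
# registered for the XPU backend (registration is added in pytorch#166056).
# The "xpu" device string and the shapes below are placeholders.
device = "xpu"
M, K, N = 64, 64, 64  # all dimensions divisible by 16, as the checks require

# mat1 must be row-major, mat2 column-major (a transposed row-major tensor).
a = torch.randn(M, K, device=device).to(torch.float8_e4m3fn)
b = torch.randn(N, K, device=device).to(torch.float8_e4m3fn).t()

# TensorWise scaling: one fp32 scale per operand (scale.numel() == 1).
scale_a = torch.tensor(1.0, device=device)
scale_b = torch.tensor(1.0, device=device)
out_tw = torch._scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)

# RowWise scaling: scale_a is (M, 1), scale_b is (1, N), both contiguous fp32.
scale_a = torch.ones(M, 1, device=device)
scale_b = torch.ones(1, N, device=device)
out_rw = torch._scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)
```

The RowWise shapes mirror what `get_joint_scaling` and `is_rowwise_scaling` in the diff below accept; TensorWise requires a single fp32 scale per operand.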
1 parent 3522e0c commit 0e7235e

4 files changed: +592 −1 lines changed
Lines changed: 342 additions & 0 deletions

@@ -0,0 +1,342 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/BlasBackend.h>
#include <ATen/WrapDimUtilsMulti.h>
#include <ATen/ceil_div.h>
#include <ATen/native/Resize.h>
#include <ATen/native/mkldnn/xpu/detail/oneDNN.h>
#include <ATen/native/xpu/Blas.h>
#include <torch/library.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_addmm_activation_native.h>
#include <ATen/ops/_efficientzerotensor.h>
#include <ATen/ops/_scaled_mm_native.h>
#include <ATen/ops/_unsafe_view_native.h>
#include <ATen/ops/abs.h>
#include <ATen/ops/addmm_native.h>
#include <ATen/ops/addmv_native.h>
#include <ATen/ops/baddbmm_native.h>
#include <ATen/ops/bmm_native.h>
#include <ATen/ops/copy_native.h>
#include <ATen/ops/dot_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_strided.h>
#include <ATen/ops/gelu.h>
#include <ATen/ops/max.h>
#include <ATen/ops/mm_native.h>
#include <ATen/ops/mul.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/relu.h>
#include <ATen/ops/scalar_tensor_native.h>
#include <ATen/ops/vdot_native.h>
#endif

namespace at::native {

using at::blas::ScalingType;
using at::blas::SwizzleType;

namespace {
/*
 * Scaling Type Determination:
 * ---------------------------
 * Conditions and corresponding Scaling Types:
 *
 * - If scale tensor is `Float8_e8m0fnu` or `Float8_e4m3fn`:
 *   - Returns BlockWise (with additional size checks).
 *
 * - Else if scale.numel() == 1:
 *   - Returns TensorWise.
 *
 * - Else if scale.dim() == 2 && scale.size(0) == outer_dim && scale.size(1) == 1:
 *   - Returns RowWise.
 *
 * - Otherwise:
 *   - Returns Error.
 */

bool is_tensorwise_scaling(const at::Tensor& t, const at::Tensor& scale) {
  return at::isFloat8Type(t.scalar_type()) &&
      scale.scalar_type() == at::kFloat && scale.numel() == 1;
}

bool is_rowwise_scaling(const at::Tensor& t, const at::Tensor& scale) {
  return (
      at::isFloat8Type(t.scalar_type()) && scale.scalar_type() == at::kFloat &&
      scale.dim() == 2 && scale.size(0) == t.size(0) && scale.size(1) == 1 &&
      scale.is_contiguous());
}

bool is_desired_scaling(
    const at::Tensor& t,
    const at::Tensor& scale,
    ScalingType desired_scaling) {
  auto result = desired_scaling == ScalingType::TensorWise
      ? is_tensorwise_scaling(t, scale)
      : is_rowwise_scaling(t, scale);
  return result;
}

std::pair<ScalingType, ScalingType> get_joint_scaling(
    std::initializer_list<std::pair<ScalingType, ScalingType>> options,
    const at::Tensor& a,
    const at::Tensor& b,
    const at::Tensor& scale_a,
    const at::Tensor& scale_b) {
  for (auto [lhs, rhs] : options) {
    if (is_desired_scaling(a, scale_a, lhs) &&
        is_desired_scaling(b.t(), scale_b.t(), rhs)) {
      return {lhs, rhs};
    }
  }
  TORCH_CHECK(
      false,
      "Invalid scaling configuration.\n"
      "- For TensorWise scaling, a and b should be float8, scales should be float and singletons.\n"
      "- For RowWise scaling, a and b should be float8, scales should be float, scale_a should be (",
      a.size(0),
      ", 1) and scale_b should be (1, ",
      b.size(1),
      "), and both should be contiguous.\n"
      "Got a.dtype()=",
      a.scalar_type(),
      ", scale_a.dtype()=",
      scale_a.scalar_type(),
      ", scale_a.size()=",
      scale_a.sizes(),
      ", scale_a.stride()=",
      scale_a.strides(),
      ", ",
      "b.dtype()=",
      b.scalar_type(),
      ", scale_b.dtype()=",
      scale_b.scalar_type(),
      ", scale_b.size()=",
      scale_b.sizes(),
      " and scale_b.stride()=",
      scale_b.strides());
}

Tensor& _scaled_gemm(
    const Tensor& mat1,
    const Tensor& mat2,
    const Tensor& scale_a,
    const Tensor& scale_b,
    const ScalingType scaling_choice_a,
    const ScalingType scaling_choice_b,
    const std::optional<Tensor>& bias,
    const bool use_fast_accum,
    Tensor& out,
    const std::optional<Tensor>& alpha = std::nullopt) {
  // TODO: scale_result and alpha are not defined or used!
  std::optional<Tensor> scaled_result = std::nullopt;
  at::native::onednn::scaled_matmul(
      mat1,
      mat2,
      out,
      scale_a,
      scale_b,
      scaling_choice_a,
      scaling_choice_b,
      bias,
      scaled_result,
      use_fast_accum);

  return out;
}

} // namespace

// Computes matrix multiply + bias while applying scaling to input and output
// matrices. Scales are only applicable when matrices are of Float8 type and
// assumed to be equal to 1.0 by default. If output matrix type is 16 or 32-bit
// type, scale_result is not applied. Known limitations:
//  - Only works if mat1 is row-major and mat2 is column-major
//  - Only works if matrix sizes are divisible by 32
//  - If 1-dimensional tensors are used then scale_a should be size =
//    mat1.size(0) and scale_b should have size = mat2.size(1)
//  Arguments:
//    - `mat1`: the first operand of the matrix multiply, can be type
//      `torch.float8_e4m3fn` or `torch.float8_e5m2`
//    - `mat2`: the second operand of the matrix multiply, can be type
//      `torch.float8_e4m3fn` or `torch.float8_e5m2`
//    - `bias`: the bias, can be type `torch.float16` or `torch.bfloat16`
//    - `out_dtype`: the output dtype, can either be a float8 or a higher
//      precision floating point type
//    - `scale_a`: a tensor with the inverse scale of `mat1`, whose
//      shape/strides/dtype depend on the scaling scheme
//    - `scale_b`: a tensor with the inverse scale of `mat2`, whose
//      shape/strides/dtype depend on the scaling scheme
//    - `scale_result`: a scalar tensor with the scale of the output, only
//      utilized if the output is a float8 type
//    - `use_fast_accum`: not applicable for XPU. For now, it should always be
//      false.
//    - `out`: a reference to the output tensor
Tensor& _scaled_mm_out_xpu(
    const Tensor& mat1,
    const Tensor& mat2,
    const Tensor& scale_a,
    const Tensor& scale_b,
    const std::optional<at::Tensor>& bias,
    const std::optional<at::Tensor>& scale_result,
    std::optional<c10::ScalarType> out_dtype,
    bool use_fast_accum,
    Tensor& out) {
  // Note: fast_accum is not supported in XPU for now.
  TORCH_CHECK(!use_fast_accum, "fast_accum is not supported in XPU for now.");

  TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix");
  TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix");

  TORCH_CHECK(
      mat1.sizes()[1] == mat2.sizes()[0],
      "mat1 and mat2 shapes cannot be multiplied (",
      mat1.sizes()[0],
      "x",
      mat1.sizes()[1],
      " and ",
      mat2.sizes()[0],
      "x",
      mat2.sizes()[1],
      ")");

  // Check what type of scaling we are doing based on inputs. This list is
  // sorted by decreasing priority.
  // List of supported datatypes for XPU with oneDNN:
  // https://uxlfoundation.github.io/oneDNN/dev_guide_matmul.html#data-types
  auto [scaling_choice_a, scaling_choice_b] = get_joint_scaling(
      {
          std::make_pair(ScalingType::TensorWise, ScalingType::TensorWise),
          std::make_pair(ScalingType::RowWise, ScalingType::RowWise),
      },
      mat1,
      mat2,
      scale_a,
      scale_b);
  TORCH_CHECK(
      !scale_result ||
          (scale_result->numel() == 1 && scale_result->scalar_type() == kFloat),
      "scale_result must be a float scalar");
  TORCH_CHECK(
      !bias || bias->numel() == mat2.sizes()[1],
      "Bias must be size ",
      mat2.sizes()[1],
      " but got ",
      bias->numel());
  TORCH_CHECK(
      mat1.sizes()[1] % 16 == 0,
      "Expected trailing dimension of mat1 to be divisible by 16 ",
      "but got mat1 shape: (",
      mat1.sizes()[0],
      "x",
      mat1.sizes()[1],
      ").");
  TORCH_CHECK(
      mat2.sizes()[0] % 16 == 0 && mat2.sizes()[1] % 16 == 0,
      "mat2 shape (",
      mat2.sizes()[0],
      "x",
      mat2.sizes()[1],
      ") must be divisible by 16");
  // Check types
  TORCH_CHECK(
      !out_dtype || *out_dtype == out.scalar_type(),
      "out_dtype must match output matrix type");
  TORCH_CHECK(
      at::isFloat8Type(mat1.scalar_type()),
      "Expected mat1 to be Float8 matrix got ",
      mat1.scalar_type());
  TORCH_CHECK(
      at::isFloat8Type(mat2.scalar_type()),
      "Expected mat2 to be Float8 matrix got ",
      mat2.scalar_type());
  // TODO: oneDNN currently only supports e4m3 with group scales on BMG. It
  // does not support 2D scales, only 1D. More checks need to be added there.

  if (bias) {
    TORCH_CHECK(
        bias->scalar_type() == kFloat ||
            bias->scalar_type() == c10::ScalarType::BFloat16 ||
            bias->scalar_type() == c10::ScalarType::Half,
        "Bias must be Float32 or BFloat16 or Half, but got ",
        bias->scalar_type());
  }

  {
    auto bias_ = bias.value_or(Tensor());
    auto scale_result_ = scale_result.value_or(Tensor());

    // NOLINTNEXTLINE(*c-array*)
    TensorArg targs[]{
        {out, "out", 0},
        {mat1, "mat1", 1},
        {mat2, "mat2", 2},
        {bias_, "bias", 3},
        {scale_a, "scale_a", 4},
        {scale_b, "scale_b", 5},
        {scale_result_, "scale_result", 6}};
    checkAllSameGPU(__func__, targs);
  }

  // Validation checks have passed; let's resize the output to its actual size.
  IntArrayRef mat1_sizes = mat1.sizes();
  IntArrayRef mat2_sizes = mat2.sizes();
  at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]});

  // If any of M, K, N is 0 - return early (the tensorwise/rowwise float8 gemm
  // kernels do not support this case).
  if (mat1_sizes[0] == 0 || mat1_sizes[1] == 0 || mat2_sizes[1] == 0) {
    // `out` was created with `at::empty`. In the case where we are multiplying
    // MxK by KxN and K is the zero dim, we need to initialize here to properly
    // return a tensor of zeros.
    if (mat1_sizes[1] == 0) {
      out.zero_();
    }

    return out;
  }

  // TODO: scale_result is not supported yet.
  return _scaled_gemm(
      mat1,
      mat2,
      scale_a,
      scale_b,
      scaling_choice_a,
      scaling_choice_b,
      bias,
      use_fast_accum,
      out);
}

Tensor _scaled_mm_xpu(
    const Tensor& mat_a,
    const Tensor& mat_b,
    const Tensor& scale_a,
    const Tensor& scale_b,
    const std::optional<at::Tensor>& bias,
    const std::optional<at::Tensor>& scale_result,
    std::optional<c10::ScalarType> out_dtype,
    bool use_fast_accum) {
  const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type());
  Tensor out = at::empty({0}, mat_a.options().dtype(out_dtype_));
  return _scaled_mm_out_xpu(
      mat_a,
      mat_b,
      scale_a,
      scale_b,
      bias,
      scale_result,
      out_dtype,
      use_fast_accum,
      out);
}

} // namespace at::native

0 commit comments