#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/BlasBackend.h>
#include <ATen/WrapDimUtilsMulti.h>
#include <ATen/ceil_div.h>
#include <ATen/native/Resize.h>
#include <ATen/native/mkldnn/xpu/detail/oneDNN.h>
#include <ATen/native/xpu/Blas.h>
#include <torch/library.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_addmm_activation_native.h>
#include <ATen/ops/_efficientzerotensor.h>
#include <ATen/ops/_scaled_mm_native.h>
#include <ATen/ops/_unsafe_view_native.h>
#include <ATen/ops/abs.h>
#include <ATen/ops/addmm_native.h>
#include <ATen/ops/addmv_native.h>
#include <ATen/ops/baddbmm_native.h>
#include <ATen/ops/bmm_native.h>
#include <ATen/ops/copy_native.h>
#include <ATen/ops/dot_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_strided.h>
#include <ATen/ops/gelu.h>
#include <ATen/ops/max.h>
#include <ATen/ops/mm_native.h>
#include <ATen/ops/mul.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/relu.h>
#include <ATen/ops/scalar_tensor_native.h>
#include <ATen/ops/vdot_native.h>
#endif

namespace at::native {

using at::blas::ScalingType;
using at::blas::SwizzleType;

namespace {
/*
 * Scaling Type Determination:
 * ---------------------------
 * Conditions and corresponding Scaling Types:
 *
 * - If scale tensor is `Float8_e8m0fnu` or `Float8_e4m3fn`:
 *   - Returns BlockWise (with additional size checks).
 *
 * - Else if scale.numel() == 1:
 *   - Returns TensorWise.
 *
 * - Else if scale.dim() == 2 && scale.size(0) == outer_dim &&
 *   scale.size(1) == 1:
 *   - Returns RowWise.
 *
 * - Otherwise:
 *   - Returns Error.
 */
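
// Illustrative example (not part of the original comment; the shapes are
// hypothetical): for an MxK @ KxN matmul with M = 64, K = 256, N = 128, the
// configurations accepted by the checks below are roughly:
//   - TensorWise: scale_a and scale_b are fp32 tensors with numel() == 1.
//   - RowWise:    scale_a is a contiguous fp32 tensor of shape (64, 1) and
//                 scale_b is a contiguous fp32 tensor of shape (1, 128).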

bool is_tensorwise_scaling(const at::Tensor& t, const at::Tensor& scale) {
  return at::isFloat8Type(t.scalar_type()) &&
      scale.scalar_type() == at::kFloat && scale.numel() == 1;
}

bool is_rowwise_scaling(const at::Tensor& t, const at::Tensor& scale) {
  return (
      at::isFloat8Type(t.scalar_type()) && scale.scalar_type() == at::kFloat &&
      scale.dim() == 2 && scale.size(0) == t.size(0) && scale.size(1) == 1 &&
      scale.is_contiguous());
}

bool is_desired_scaling(
    const at::Tensor& t,
    const at::Tensor& scale,
    ScalingType desired_scaling) {
  auto result = desired_scaling == ScalingType::TensorWise
      ? is_tensorwise_scaling(t, scale)
      : is_rowwise_scaling(t, scale);
  return result;
}

std::pair<ScalingType, ScalingType> get_joint_scaling(
    std::initializer_list<std::pair<ScalingType, ScalingType>> options,
    const at::Tensor& a,
    const at::Tensor& b,
    const at::Tensor& scale_a,
    const at::Tensor& scale_b) {
  for (auto [lhs, rhs] : options) {
    if (is_desired_scaling(a, scale_a, lhs) &&
        is_desired_scaling(b.t(), scale_b.t(), rhs)) {
      return {lhs, rhs};
    }
  }
  TORCH_CHECK(
      false,
      "Invalid scaling configuration.\n"
      "- For TensorWise scaling, a and b should be float8, scales should be float and singletons.\n"
      "- For RowWise scaling, a and b should be float8, scales should be float, scale_a should be (",
      a.size(0),
      ", 1) and scale_b should be (1, ",
      b.size(1),
      "), and both should be contiguous.\n"
      "Got a.dtype()=",
      a.scalar_type(),
      ", scale_a.dtype()=",
      scale_a.scalar_type(),
      ", scale_a.size()=",
      scale_a.sizes(),
      ", scale_a.stride()=",
      scale_a.strides(),
      ", ",
      "b.dtype()=",
      b.scalar_type(),
      ", scale_b.dtype()=",
      scale_b.scalar_type(),
      ", scale_b.size()=",
      scale_b.sizes(),
      " and scale_b.stride()=",
      scale_b.strides());
}

Tensor& _scaled_gemm(
    const Tensor& mat1,
    const Tensor& mat2,
    const Tensor& scale_a,
    const Tensor& scale_b,
    const ScalingType scaling_choice_a,
    const ScalingType scaling_choice_b,
    const std::optional<Tensor>& bias,
    const bool use_fast_accum,
    Tensor& out,
    const std::optional<Tensor>& alpha = std::nullopt) {
  // TODO: scale_result and alpha are not defined or used!
  std::optional<Tensor> scaled_result = std::nullopt;
  at::native::onednn::scaled_matmul(
      mat1,
      mat2,
      out,
      scale_a,
      scale_b,
      scaling_choice_a,
      scaling_choice_b,
      bias,
      scaled_result,
      use_fast_accum);

  return out;
}

} // namespace

// Computes matrix multiply + bias while applying scaling to the input and
// output matrices. Scales are only applicable when matrices are of Float8
// type and are assumed to be equal to 1.0 by default. If the output matrix
// is a 16- or 32-bit type, scale_result is not applied.
// Known limitations:
// - Only works if mat1 is row-major and mat2 is column-major
// - Only works if matrix sizes are divisible by 32
// - If 1-dimensional tensors are used, then scale_a should have size =
//   mat1.size(0) and scale_b should have size = mat2.size(1)
// Arguments:
// - `mat1`: the first operand of the matrix multiply, can be type
//   `torch.float8_e4m3fn` or `torch.float8_e5m2`
// - `mat2`: the second operand of the matrix multiply, can be type
//   `torch.float8_e4m3fn` or `torch.float8_e5m2`
// - `bias`: the bias, can be type `torch.float16` or `torch.bfloat16`
// - `out_dtype`: the output dtype, can either be a float8 or a higher
//   precision floating point type
// - `scale_a`: a tensor with the inverse scale of `mat1`, whose
//   shape/strides/dtype depend on the scaling scheme
// - `scale_b`: a tensor with the inverse scale of `mat2`, whose
//   shape/strides/dtype depend on the scaling scheme
// - `scale_result`: a scalar tensor with the scale of the output, only
//   utilized if the output is a float8 type
// - `use_fast_accum`: not applicable for XPU; for now, it should always be
//   false
// - `out`: a reference to the output tensor
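//
// A minimal usage sketch (illustrative only, not from the original file): the
// shapes, the BF16 output dtype, and the call through the generated
// at::_scaled_mm wrapper are assumptions made for the example. It shows a
// tensor-wise scaled fp8 matmul with a row-major `a` and a column-major `b`:
//
//   // fp8 operands on the XPU device: a is 64x256 row-major, b is 256x128
//   // column-major (created as 128x256, then transposed).
//   at::Tensor a = at::randn({64, 256}, at::dtype(at::kFloat).device(at::kXPU))
//                      .to(at::kFloat8_e4m3fn);
//   at::Tensor b = at::randn({128, 256}, at::dtype(at::kFloat).device(at::kXPU))
//                      .to(at::kFloat8_e4m3fn)
//                      .t();
//   // Tensor-wise inverse scales: single-element fp32 tensors.
//   at::Tensor scale_a = at::ones({1}, at::dtype(at::kFloat).device(at::kXPU));
//   at::Tensor scale_b = at::ones({1}, at::dtype(at::kFloat).device(at::kXPU));
//   at::Tensor out = at::_scaled_mm(
//       a, b, scale_a, scale_b,
//       /*bias=*/std::nullopt, /*scale_result=*/std::nullopt,
//       /*out_dtype=*/at::kBFloat16, /*use_fast_accum=*/false);
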
Tensor& _scaled_mm_out_xpu(
    const Tensor& mat1,
    const Tensor& mat2,
    const Tensor& scale_a,
    const Tensor& scale_b,
    const std::optional<at::Tensor>& bias,
    const std::optional<at::Tensor>& scale_result,
    std::optional<c10::ScalarType> out_dtype,
    bool use_fast_accum,
    Tensor& out) {
  // Note: fast_accum is not supported in XPU for now.
  TORCH_CHECK(!use_fast_accum, "fast_accum is not supported in XPU for now.");

  TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix");
  TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix");

  TORCH_CHECK(
      mat1.sizes()[1] == mat2.sizes()[0],
      "mat1 and mat2 shapes cannot be multiplied (",
      mat1.sizes()[0],
      "x",
      mat1.sizes()[1],
      " and ",
      mat2.sizes()[0],
      "x",
      mat2.sizes()[1],
      ")");

  // Check what type of scaling we are doing based on inputs. This list is
  // sorted by decreasing priority.

  // List of supported datatypes for XPU with oneDNN:
  // https://uxlfoundation.github.io/oneDNN/dev_guide_matmul.html#data-types
  auto [scaling_choice_a, scaling_choice_b] = get_joint_scaling(
      {
          std::make_pair(ScalingType::TensorWise, ScalingType::TensorWise),
          std::make_pair(ScalingType::RowWise, ScalingType::RowWise),
      },
      mat1,
      mat2,
      scale_a,
      scale_b);
  TORCH_CHECK(
      !scale_result ||
          (scale_result->numel() == 1 && scale_result->scalar_type() == kFloat),
      "scale_result must be a float scalar");
  TORCH_CHECK(
      !bias || bias->numel() == mat2.sizes()[1],
      "Bias must be size ",
      mat2.sizes()[1],
      " but got ",
      bias->numel());
  TORCH_CHECK(
      mat1.sizes()[1] % 16 == 0,
      "Expected trailing dimension of mat1 to be divisible by 16 ",
      "but got mat1 shape: (",
      mat1.sizes()[0],
      "x",
      mat1.sizes()[1],
      ").");
  TORCH_CHECK(
      mat2.sizes()[0] % 16 == 0 && mat2.sizes()[1] % 16 == 0,
      "mat2 shape (",
      mat2.sizes()[0],
      "x",
      mat2.sizes()[1],
      ") must be divisible by 16");
  // Check types
  TORCH_CHECK(
      !out_dtype || *out_dtype == out.scalar_type(),
      "out_dtype must match output matrix type");
  TORCH_CHECK(
      at::isFloat8Type(mat1.scalar_type()),
      "Expected mat1 to be Float8 matrix got ",
      mat1.scalar_type());
  TORCH_CHECK(
      at::isFloat8Type(mat2.scalar_type()),
      "Expected mat2 to be Float8 matrix got ",
      mat2.scalar_type());
  // TODO: oneDNN currently only supports e4m3 with group scales on BMG. It
  // does not support 2D scales, only 1D. More checks need to be added here.

  if (bias) {
    TORCH_CHECK(
        bias->scalar_type() == kFloat ||
            bias->scalar_type() == c10::ScalarType::BFloat16 ||
            bias->scalar_type() == c10::ScalarType::Half,
        "Bias must be Float32 or BFloat16 or Half, but got ",
        bias->scalar_type());
  }

  {
    auto bias_ = bias.value_or(Tensor());
    auto scale_result_ = scale_result.value_or(Tensor());

    // NOLINTNEXTLINE(*c-array*)
    TensorArg targs[]{
        {out, "out", 0},
        {mat1, "mat1", 1},
        {mat2, "mat2", 2},
        {bias_, "bias", 3},
        {scale_a, "scale_a", 4},
        {scale_b, "scale_b", 5},
        {scale_result_, "scale_result", 6}};
    checkAllSameGPU(__func__, targs);
  }

  // Validation checks have passed; resize the output to its actual size.
  IntArrayRef mat1_sizes = mat1.sizes();
  IntArrayRef mat2_sizes = mat2.sizes();
  at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]});

  // If any of M, K, N is 0 - return early (the tensorwise/rowwise float8 gemm
  // kernels do not support this case).
  if (mat1_sizes[0] == 0 || mat1_sizes[1] == 0 || mat2_sizes[1] == 0) {
    // `out` was created with `at::empty`. In the case where we are multiplying
    // MxK by KxN and K is the zero dim, we need to initialize here to properly
    // return a tensor of zeros.
    if (mat1_sizes[1] == 0) {
      out.zero_();
    }

    return out;
  }

  // TODO: scale_result is not supported yet.
  return _scaled_gemm(
      mat1,
      mat2,
      scale_a,
      scale_b,
      scaling_choice_a,
      scaling_choice_b,
      bias,
      use_fast_accum,
      out);
}

Tensor _scaled_mm_xpu(
    const Tensor& mat_a,
    const Tensor& mat_b,
    const Tensor& scale_a,
    const Tensor& scale_b,
    const std::optional<at::Tensor>& bias,
    const std::optional<at::Tensor>& scale_result,
    std::optional<c10::ScalarType> out_dtype,
    bool use_fast_accum) {
  const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type());
  Tensor out = at::empty({0}, mat_a.options().dtype(out_dtype_));
  return _scaled_mm_out_xpu(
      mat_a,
      mat_b,
      scale_a,
      scale_b,
      bias,
      scale_result,
      out_dtype,
      use_fast_accum,
      out);
}

} // namespace at::native