Skip to content

Commit 583de90

Browse files
zzf/add custom fallback for addmm linear bmm mm (#566)
* add custom fallback for addmm linear bmm mm * add custom fallback for addmm linear bmm mm * add custom fallback for addmm linear bmm mm * add custom fallback for addmm linear bmm mm
1 parent 99dadd7 commit 583de90

File tree

2 files changed

+70
-3
lines changed

2 files changed

+70
-3
lines changed

dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,7 @@
481481
interface: diopiSum(ctx, out, self_dtype_diopi, diopi_size)
482482

483483
- schema: "addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)"
484+
custom_fallback: True
484485
custom_code_at_the_beginning: |
485486
interface: diopiAddmm(&context, out, self, mat1, mat2, beta, alpha)
486487

@@ -744,6 +745,7 @@
744745
interface: diopiLinearBackward(ctx, grad_input, grad_weight, grad_bias, grad_output, input, weight)
745746

746747
- schema: "linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor"
748+
custom_fallback: True
747749
device: [all, -cuda]
748750
custom_code_at_the_beginning: |
749751
std::vector<int64_t> output_size(input.sizes().begin(), input.sizes().end());
@@ -1470,6 +1472,7 @@
14701472
interface: diopiCosInp(ctx, self)
14711473

14721474
- schema: "bmm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)"
1475+
custom_fallback: True
14731476
interface: diopiBmm(ctx, out, self, mat2)
14741477

14751478
- schema: "silu.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)"
@@ -1484,6 +1487,7 @@
14841487
interface: diopiNormalInp(ctx, self, mean, std, generator)
14851488

14861489
- schema: "mm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)"
1490+
custom_fallback: True
14871491
interface: diopiMm(ctx, out, self, mat2)
14881492

14891493
- schema: "matmul(Tensor self, Tensor other) -> Tensor"
@@ -2434,6 +2438,7 @@
24342438

24352439
# this copy_ aten op may use both diopiCastDtype and diopiCopyInp. it's a proxy/composite op
24362440
- schema: copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)
2441+
autocompare: disable
24372442
dummy_call_diopi: True
24382443
custom_fallback: True
24392444
device: [cuda, camb, ascend, droplet, supa, kunlunxin]
@@ -2445,6 +2450,7 @@
24452450

24462451
# vendor who has no fully implemented diopi and proper fallback DIPUCopy sub-class
24472452
- schema: copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)
2453+
autocompare: disable
24482454
custom_fallback: True
24492455
dummy_call_diopi: True
24502456
custom_code_at_the_beginning: |
@@ -2453,6 +2459,7 @@
24532459
interface: diopiCopyInp(ctx, src, self)
24542460

24552461
- schema: _amp_foreach_non_finite_check_and_unscale_(at::TensorList self, Tensor(b!) found_inf, Tensor inv_scale) -> void
2462+
autocompare: disable
24562463
custom_fallback: True
24572464
custom_code_at_the_beginning: |
24582465
std::vector<diopiTensorHandle_t> diopiTensorHandles(self.size(), nullptr);

dipu/torch_dipu/csrc_dipu/aten/ops/CustomFallbackFunctions.hpp

Lines changed: 63 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ static c10::optional<at::Tensor> dipu_to_cpu(
1717
return cpu_tensor;
1818
}
1919

20-
static at::Tensor to_cpu_no_half(const at::Tensor& devtensor) {
20+
static at::Tensor to_cpu_with_half_to_float(const at::Tensor& devtensor) {
2121
auto cpu_tensor = devtensor.cpu();
2222
auto intype = devtensor.options().dtype_opt()->toScalarType();
2323
if (intype == at::ScalarType::Half) {
@@ -30,8 +30,9 @@ static at::Tensor& custom_fallback_dipu_silu_out(const at::Tensor& self,
3030
at::Tensor& out) {
3131
DIPU_OP_LOG_WARNING_ONCE("custom fallback to cpu, name=silu_out"
3232
<< std::endl);
33-
auto self_cpu = to_cpu_no_half(self);
34-
auto out_cpu = to_cpu_no_half(self);
33+
auto self_cpu = to_cpu_with_half_to_float(self);
34+
auto out_cpu = to_cpu_with_half_to_float(self);
35+
3536
// NOLINTNEXTLINE(readability-suspicious-call-argument): It's the correct order
3637
out_cpu = at::silu_out(self_cpu, out_cpu);
3738
out.copy_(out_cpu);
@@ -339,5 +340,64 @@ at::Tensor& custom_fallback_dipu__amp_update_scale_(at::Tensor& current_scale,
339340
double backoff_factor,
340341
int64_t growth_interval);
341342

343+
static at::Tensor& custom_fallback_dipu_addmm_out(
344+
const at::Tensor& self, const at::Tensor& mat1, const at::Tensor& mat2,
345+
const at::Scalar& beta, const at::Scalar& alpha, at::Tensor& out) {
346+
auto self_cpu = to_cpu_with_half_to_float(self);
347+
auto mat1_cpu = to_cpu_with_half_to_float(mat1);
348+
auto mat2_cpu = to_cpu_with_half_to_float(mat2);
349+
auto out_cpu = to_cpu_with_half_to_float(out);
350+
out_cpu = at::addmm_out(out_cpu, self_cpu, mat1_cpu, mat2_cpu, beta, alpha);
351+
out.copy_(out_cpu);
352+
return out;
353+
}
354+
355+
static at::Tensor& custom_fallback_dipu_bmm_out(const at::Tensor& self,
356+
const at::Tensor& mat2,
357+
at::Tensor& out) {
358+
auto self_cpu = to_cpu_with_half_to_float(self);
359+
auto mat2_cpu = to_cpu_with_half_to_float(mat2);
360+
auto out_cpu = to_cpu_with_half_to_float(out);
361+
out_cpu = at::bmm_out(out_cpu, self_cpu, mat2_cpu);
362+
out.copy_(out_cpu);
363+
return out;
364+
}
365+
366+
static at::Tensor& custom_fallback_dipu_mm_out(const at::Tensor& self,
367+
const at::Tensor& mat2,
368+
at::Tensor& out) {
369+
auto self_cpu = to_cpu_with_half_to_float(self);
370+
auto mat2_cpu = to_cpu_with_half_to_float(mat2);
371+
auto out_cpu = to_cpu_with_half_to_float(out);
372+
out_cpu = at::mm_out(out_cpu, self_cpu, mat2_cpu);
373+
out.copy_(out_cpu);
374+
return out;
375+
}
376+
377+
static at::Tensor custom_fallback_dipu_linear(
378+
const at::Tensor& input, const at::Tensor& weight,
379+
const c10::optional<at::Tensor>& bias) {
380+
auto input_cpu = to_cpu_with_half_to_float(input);
381+
auto weight_cpu = to_cpu_with_half_to_float(weight);
382+
c10::optional<at::Tensor> bias_cpu = c10::nullopt;
383+
384+
at::Tensor out;
385+
at::Tensor out_cpu;
386+
387+
if (bias.has_value() && bias.value().defined()) {
388+
if (bias.value().options().dtype_opt()->toScalarType() ==
389+
at::ScalarType::Half) {
390+
bias_cpu = bias.value().to(at::ScalarType::Float).cpu();
391+
} else {
392+
bias_cpu = bias.value().cpu();
393+
}
394+
}
395+
396+
out_cpu = at::linear(input_cpu, weight_cpu, bias_cpu);
397+
out = out_cpu.to(input.device())
398+
.to(input.options().dtype_opt()->toScalarType());
399+
return out;
400+
}
401+
342402
} // namespace native
343403
} // namespace dipu

0 commit comments

Comments (0)