@@ -17,7 +17,7 @@ static c10::optional<at::Tensor> dipu_to_cpu(
1717 return cpu_tensor;
1818}
1919
20- static at::Tensor to_cpu_no_half (const at::Tensor& devtensor) {
20+ static at::Tensor to_cpu_with_half_to_float (const at::Tensor& devtensor) {
2121 auto cpu_tensor = devtensor.cpu ();
2222 auto intype = devtensor.options ().dtype_opt ()->toScalarType ();
2323 if (intype == at::ScalarType::Half) {
@@ -30,8 +30,9 @@ static at::Tensor& custom_fallback_dipu_silu_out(const at::Tensor& self,
3030 at::Tensor& out) {
3131 DIPU_OP_LOG_WARNING_ONCE (" custom fallback to cpu, name=silu_out"
3232 << std::endl);
33- auto self_cpu = to_cpu_no_half (self);
34- auto out_cpu = to_cpu_no_half (self);
33+ auto self_cpu = to_cpu_with_half_to_float (self);
34+ auto out_cpu = to_cpu_with_half_to_float (self);
35+
3536 // NOLINTNEXTLINE(readability-suspicious-call-argument): It's the correct order
3637 out_cpu = at::silu_out (self_cpu, out_cpu);
3738 out.copy_ (out_cpu);
@@ -339,5 +340,64 @@ at::Tensor& custom_fallback_dipu__amp_update_scale_(at::Tensor& current_scale,
339340 double backoff_factor,
340341 int64_t growth_interval);
341342
343+ static at::Tensor& custom_fallback_dipu_addmm_out (
344+ const at::Tensor& self, const at::Tensor& mat1, const at::Tensor& mat2,
345+ const at::Scalar& beta, const at::Scalar& alpha, at::Tensor& out) {
346+ auto self_cpu = to_cpu_with_half_to_float (self);
347+ auto mat1_cpu = to_cpu_with_half_to_float (mat1);
348+ auto mat2_cpu = to_cpu_with_half_to_float (mat2);
349+ auto out_cpu = to_cpu_with_half_to_float (out);
350+ out_cpu = at::addmm_out (out_cpu, self_cpu, mat1_cpu, mat2_cpu, beta, alpha);
351+ out.copy_ (out_cpu);
352+ return out;
353+ }
354+
355+ static at::Tensor& custom_fallback_dipu_bmm_out (const at::Tensor& self,
356+ const at::Tensor& mat2,
357+ at::Tensor& out) {
358+ auto self_cpu = to_cpu_with_half_to_float (self);
359+ auto mat2_cpu = to_cpu_with_half_to_float (mat2);
360+ auto out_cpu = to_cpu_with_half_to_float (out);
361+ out_cpu = at::bmm_out (out_cpu, self_cpu, mat2_cpu);
362+ out.copy_ (out_cpu);
363+ return out;
364+ }
365+
366+ static at::Tensor& custom_fallback_dipu_mm_out (const at::Tensor& self,
367+ const at::Tensor& mat2,
368+ at::Tensor& out) {
369+ auto self_cpu = to_cpu_with_half_to_float (self);
370+ auto mat2_cpu = to_cpu_with_half_to_float (mat2);
371+ auto out_cpu = to_cpu_with_half_to_float (out);
372+ out_cpu = at::mm_out (out_cpu, self_cpu, mat2_cpu);
373+ out.copy_ (out_cpu);
374+ return out;
375+ }
376+
377+ static at::Tensor custom_fallback_dipu_linear (
378+ const at::Tensor& input, const at::Tensor& weight,
379+ const c10::optional<at::Tensor>& bias) {
380+ auto input_cpu = to_cpu_with_half_to_float (input);
381+ auto weight_cpu = to_cpu_with_half_to_float (weight);
382+ c10::optional<at::Tensor> bias_cpu = c10::nullopt ;
383+
384+ at::Tensor out;
385+ at::Tensor out_cpu;
386+
387+ if (bias.has_value () && bias.value ().defined ()) {
388+ if (bias.value ().options ().dtype_opt ()->toScalarType () ==
389+ at::ScalarType::Half) {
390+ bias_cpu = bias.value ().to (at::ScalarType::Float).cpu ();
391+ } else {
392+ bias_cpu = bias.value ().cpu ();
393+ }
394+ }
395+
396+ out_cpu = at::linear (input_cpu, weight_cpu, bias_cpu);
397+ out = out_cpu.to (input.device ())
398+ .to (input.options ().dtype_opt ()->toScalarType ());
399+ return out;
400+ }
401+
342402} // namespace native
343403} // namespace dipu