Skip to content

Commit a94350c

Browse files
authored
fix MM nullptr from zero bias
Differential Revision: D80487955 Pull Request resolved: #13523
1 parent 14bf790 commit a94350c

File tree

2 files changed

+24
-3
lines changed

2 files changed

+24
-3
lines changed

backends/cadence/hifi/kernels/kernels.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,19 @@ memcpy(void* dst, const void* src, size_t num_bytes) {
2121
}
2222

2323
void* allocate_temp_memory(KernelRuntimeContext& ctx, size_t size) {
24+
ET_LOG(Info, "Attempting to allocate %zu bytes of temp memory", size);
2425
Result<void*> temp_mem_res = ctx.allocate_temp(size);
25-
return temp_mem_res.ok() ? temp_mem_res.get() : nullptr;
26+
if (temp_mem_res.ok()) {
27+
void* ptr = temp_mem_res.get();
28+
ET_LOG(Info, "Successfully allocated temp memory at %p", ptr);
29+
return ptr;
30+
} else {
31+
ET_LOG(
32+
Error,
33+
"Failed to allocate temp memory, error: 0x%x",
34+
static_cast<uint32_t>(temp_mem_res.error()));
35+
return nullptr;
36+
}
2637
}
2738

2839
// Quantize a fp32 value to an int8_t/uint8_t value

backends/cadence/hifi/operators/op_mm.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,15 @@ Tensor& mm_out(
7979
(WORD32* __restrict__)kernels::allocate_temp_memory(
8080
ctx, (n * p) * sizeof(WORD32));
8181

82+
// Allocate zero-initialized bias for matmul function (it doesn't accept
83+
// NULL)
84+
FLOAT32* __restrict__ p_bias_zero =
85+
(FLOAT32* __restrict__)kernels::allocate_temp_memory(
86+
ctx, m * sizeof(FLOAT32));
87+
88+
// Initialize bias to zero since mm operation has no bias
89+
memset(p_bias_zero, 0, m * sizeof(FLOAT32));
90+
8291
WORD32 p_inp_shape[2];
8392
p_inp_shape[0] = n;
8493
p_inp_shape[1] = p;
@@ -109,19 +118,20 @@ Tensor& mm_out(
109118

110119
const FLOAT32* __restrict__ p_vec = (const FLOAT32* __restrict__)p_o;
111120

121+
// mm will always be converted to addmm and to linear, and move transpose to
122+
// graph
112123
WORD32 val = xa_nn_matmul_f32xf32_f32(
113124
p_out,
114125
p_mat1,
115126
p_vec,
116-
NULL,
127+
p_bias_zero,
117128
rows,
118129
cols1,
119130
row_stride1,
120131
vec_count,
121132
vec_offset,
122133
out_offset,
123134
out_stride);
124-
125135
return out;
126136
}
127137

0 commit comments

Comments
 (0)