Skip to content

Commit b6f178e

Browse files
authored
Fix LogSumExp Implementation for OpenVINO backend. (#21480)
* Update math.py * Update excluded_concrete_tests.txt
1 parent 503bcf5 commit b6f178e

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,6 @@ MathOpsCorrectnessTest::test_istft4
211211
MathOpsCorrectnessTest::test_istft5
212212
MathOpsCorrectnessTest::test_istft6
213213
MathOpsCorrectnessTest::test_logdet
214-
MathOpsCorrectnessTest::test_logsumexp
215214
MathOpsCorrectnessTest::test_rfft0
216215
MathOpsCorrectnessTest::test_rfft1
217216
MathOpsCorrectnessTest::test_rfft2
@@ -238,4 +237,4 @@ TestMathErrors::test_invalid_fft_length
238237
TestMathErrors::test_istft_invalid_window_shape_2D_inputs
239238
TestMathErrors::test_stft_invalid_input_type
240239
TestMathErrors::test_stft_invalid_window
241-
TestMathErrors::test_stft_invalid_window_shape
240+
TestMathErrors::test_stft_invalid_window_shape

keras/src/backend/openvino/math.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,17 @@ def logsumexp(x, axis=None, keepdims=False):
4444
axis = list(axis)
4545
axis = ov_opset.constant(axis, Type.i32).output(0)
4646
const_zero = ov_opset.constant(0, x.get_element_type()).output(0)
47-
reduce_max = ov_opset.reduce_max(x, axis, keepdims).output(0)
47+
# Use keepdims=True for reduce_max to ensure proper broadcasting
48+
reduce_max = ov_opset.reduce_max(x, axis, True).output(0)
4849
is_finite = ov_opset.is_finite(reduce_max).output(0)
4950
norm_max = ov_opset.select(is_finite, reduce_max, const_zero).output(0)
5051
norm_max_sub = ov_opset.subtract(x, norm_max).output(0)
5152
exp_norm_max = ov_opset.exp(norm_max_sub).output(0)
5253
sum_exp = ov_opset.reduce_sum(exp_norm_max, axis, keepdims).output(0)
5354
log_sum_exp = ov_opset.log(sum_exp).output(0)
55+
# Squeeze norm_max if needed to match dimensions
56+
if not keepdims:
57+
norm_max = ov_opset.squeeze(norm_max, axis).output(0)
5458
log_sum_exp = ov_opset.add(norm_max, log_sum_exp).output(0)
5559
return OpenVINOKerasTensor(log_sum_exp)
5660

0 commit comments

Comments
 (0)