Skip to content

Commit f2b8b92

Browse files
[OpenVINO backend] support categorical (#21437)
1 parent 2615b5b commit f2b8b92

File tree

3 files changed

+112
-4
lines changed

3 files changed

+112
-4
lines changed

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,61 @@ MathOpsCorrectnessTest::test_stft3
224224
MathOpsCorrectnessTest::test_stft4
225225
MathOpsCorrectnessTest::test_stft5
226226
MathOpsCorrectnessTest::test_stft6
227+
RandomCorrectnessTest::test_beta0
228+
RandomCorrectnessTest::test_beta1
229+
RandomCorrectnessTest::test_beta2
230+
RandomCorrectnessTest::test_binomial0
231+
RandomCorrectnessTest::test_binomial1
232+
RandomCorrectnessTest::test_binomial2
233+
RandomCorrectnessTest::test_dropout
234+
RandomCorrectnessTest::test_dropout_noise_shape
235+
RandomCorrectnessTest::test_gamma0
236+
RandomCorrectnessTest::test_gamma1
237+
RandomCorrectnessTest::test_gamma2
238+
RandomCorrectnessTest::test_randint0
239+
RandomCorrectnessTest::test_randint1
240+
RandomCorrectnessTest::test_randint2
241+
RandomCorrectnessTest::test_randint3
242+
RandomCorrectnessTest::test_randint4
243+
RandomCorrectnessTest::test_shuffle
244+
RandomCorrectnessTest::test_truncated_normal0
245+
RandomCorrectnessTest::test_truncated_normal1
246+
RandomCorrectnessTest::test_truncated_normal2
247+
RandomCorrectnessTest::test_truncated_normal3
248+
RandomCorrectnessTest::test_truncated_normal4
249+
RandomCorrectnessTest::test_truncated_normal5
250+
RandomCorrectnessTest::test_uniform0
251+
RandomCorrectnessTest::test_uniform1
252+
RandomCorrectnessTest::test_uniform2
253+
RandomCorrectnessTest::test_uniform3
254+
RandomCorrectnessTest::test_uniform4
255+
RandomBehaviorTest::test_beta_tf_data_compatibility
256+
RandomDTypeTest::test_beta_bfloat16
257+
RandomDTypeTest::test_beta_float16
258+
RandomDTypeTest::test_beta_float32
259+
RandomDTypeTest::test_beta_float64
260+
RandomDTypeTest::test_binomial_bfloat16
261+
RandomDTypeTest::test_binomial_float16
262+
RandomDTypeTest::test_binomial_float32
263+
RandomDTypeTest::test_binomial_float64
264+
RandomDTypeTest::test_dropout_bfloat16
265+
RandomDTypeTest::test_dropout_float16
266+
RandomDTypeTest::test_dropout_float32
267+
RandomDTypeTest::test_dropout_float64
268+
RandomDTypeTest::test_gamma_bfloat16
269+
RandomDTypeTest::test_gamma_float16
270+
RandomDTypeTest::test_gamma_float32
271+
RandomDTypeTest::test_gamma_float64
272+
RandomDTypeTest::test_normal_bfloat16
273+
RandomDTypeTest::test_randint_int16
274+
RandomDTypeTest::test_randint_int32
275+
RandomDTypeTest::test_randint_int64
276+
RandomDTypeTest::test_randint_int8
277+
RandomDTypeTest::test_randint_uint16
278+
RandomDTypeTest::test_randint_uint32
279+
RandomDTypeTest::test_randint_uint8
280+
RandomDTypeTest::test_truncated_normal_bfloat16
281+
RandomDTypeTest::test_uniform_bfloat16
227282
SegmentSumTest::test_segment_sum_call
228283
SegmentMaxTest::test_segment_max_call
229284
TestMathErrors::test_invalid_fft_length

keras/src/backend/openvino/excluded_tests.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ keras/src/ops/linalg_test.py
3232
keras/src/ops/nn_test.py
3333
keras/src/optimizers
3434
keras/src/quantizers
35-
keras/src/random
35+
keras/src/random/seed_generator_test.py
3636
keras/src/regularizers
3737
keras/src/saving
3838
keras/src/trainers

keras/src/backend/openvino/random.py

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from keras.src.backend.openvino.core import OPENVINO_DTYPES
77
from keras.src.backend.openvino.core import OpenVINOKerasTensor
88
from keras.src.backend.openvino.core import convert_to_numpy
9+
from keras.src.backend.openvino.core import get_ov_output
910
from keras.src.random.seed_generator import SeedGenerator
1011
from keras.src.random.seed_generator import draw_seed
1112
from keras.src.random.seed_generator import make_default_seed
@@ -39,9 +40,61 @@ def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
3940

4041

4142
def categorical(logits, num_samples, dtype="int64", seed=None):
42-
raise NotImplementedError(
43-
"`categorical` is not supported with openvino backend"
44-
)
43+
dtype = dtype or "int64"
44+
ov_dtype = OPENVINO_DTYPES[dtype]
45+
logits = get_ov_output(logits)
46+
47+
zero_const = ov_opset.constant(0, Type.i32).output(0)
48+
one_const = ov_opset.constant(1, Type.i32).output(0)
49+
neg_one_const = ov_opset.constant(-1, Type.i32).output(0)
50+
51+
# Compute probabilities and cumulative sum
52+
probs = ov_opset.softmax(logits, axis=-1).output(0)
53+
cumsum_probs = ov_opset.cumsum(probs, neg_one_const).output(0)
54+
55+
# Get shape and compute batch dimensions
56+
logits_shape = ov_opset.shape_of(logits, Type.i32).output(0)
57+
rank = ov_opset.shape_of(logits_shape, Type.i32).output(0)
58+
rank_scalar = ov_opset.squeeze(rank, zero_const).output(0)
59+
rank_minus_1 = ov_opset.subtract(rank_scalar, one_const).output(0)
60+
61+
# Extract batch shape (all dimensions except last)
62+
batch_indices = ov_opset.range(
63+
zero_const, rank_minus_1, one_const, output_type=Type.i32
64+
).output(0)
65+
batch_shape = ov_opset.gather(logits_shape, batch_indices, axis=0).output(0)
66+
67+
# Create final shape [batch_dims..., num_samples]
68+
num_samples_const = ov_opset.constant([num_samples], Type.i32).output(0)
69+
final_shape = ov_opset.concat(
70+
[batch_shape, num_samples_const], axis=0
71+
).output(0)
72+
73+
seed_tensor = draw_seed(seed)
74+
if isinstance(seed_tensor, OpenVINOKerasTensor):
75+
seed1, seed2 = convert_to_numpy(seed_tensor)
76+
else:
77+
seed1, seed2 = seed_tensor.data
78+
79+
probs_dtype = probs.get_element_type()
80+
zero_float = ov_opset.constant(0.0, probs_dtype).output(0)
81+
one_float = ov_opset.constant(1.0, probs_dtype).output(0)
82+
83+
rand = ov_opset.random_uniform(
84+
final_shape, zero_float, one_float, probs_dtype, seed1, seed2
85+
).output(0)
86+
87+
rand_unsqueezed = ov_opset.unsqueeze(rand, neg_one_const).output(0)
88+
cumsum_unsqueezed = ov_opset.unsqueeze(cumsum_probs, one_const).output(0)
89+
90+
# Count how many cumulative probabilities each random number exceeds
91+
greater = ov_opset.greater(rand_unsqueezed, cumsum_unsqueezed).output(0)
92+
samples = ov_opset.reduce_sum(
93+
ov_opset.convert(greater, Type.i32).output(0), neg_one_const
94+
).output(0)
95+
96+
result = ov_opset.convert(samples, ov_dtype).output(0)
97+
return OpenVINOKerasTensor(result)
4598

4699

47100
def randint(shape, minval, maxval, dtype="int32", seed=None):

0 commit comments

Comments
 (0)