Skip to content

Commit b0dcf76

Browse files
committed
Fix tests
1 parent 6c06864 commit b0dcf76

File tree

1 file changed

+16
-16
lines changed

1 file changed

+16
-16
lines changed

lib/bumblebee/utils/nx.ex

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -75,22 +75,22 @@ defmodule Bumblebee.Utils.Nx do
7575
iex> [first, second] = Bumblebee.Utils.Nx.batch_to_list(outputs)
7676
iex> first.x
7777
#Nx.Tensor<
78-
s64[2]
78+
s32[2]
7979
[0, 0]
8080
>
8181
iex> second.x
8282
#Nx.Tensor<
83-
s64[2]
83+
s32[2]
8484
[1, 1]
8585
>
8686
iex> first.y
8787
#Nx.Tensor<
88-
s64
88+
s32
8989
0
9090
>
9191
iex> second.y
9292
#Nx.Tensor<
93-
s64
93+
s32
9494
1
9595
>
9696
@@ -122,7 +122,7 @@ defmodule Bumblebee.Utils.Nx do
122122
iex> result = Bumblebee.Utils.Nx.composite_concatenate(left, right)
123123
iex> result.x
124124
#Nx.Tensor<
125-
s64[4][2]
125+
s32[4][2]
126126
[
127127
[0, 0],
128128
[1, 1],
@@ -132,7 +132,7 @@ defmodule Bumblebee.Utils.Nx do
132132
>
133133
iex> result.y
134134
#Nx.Tensor<
135-
s64[4]
135+
s32[4]
136136
[0, 1, 2, 3]
137137
>
138138
@@ -164,7 +164,7 @@ defmodule Bumblebee.Utils.Nx do
164164
iex> result = Bumblebee.Utils.Nx.composite_unflatten_batch(output, 2)
165165
iex> result.x
166166
#Nx.Tensor<
167-
s64[2][1][2]
167+
s32[2][1][2]
168168
[
169169
[
170170
[0, 0]
@@ -176,7 +176,7 @@ defmodule Bumblebee.Utils.Nx do
176176
>
177177
iex> result.y
178178
#Nx.Tensor<
179-
s64[2][1]
179+
s32[2][1]
180180
[
181181
[0],
182182
[1]
@@ -205,12 +205,12 @@ defmodule Bumblebee.Utils.Nx do
205205
iex> result = Bumblebee.Utils.Nx.composite_flatten_batch(output)
206206
iex> result.x
207207
#Nx.Tensor<
208-
s64[4]
208+
s32[4]
209209
[0, 0, 1, 1]
210210
>
211211
iex> result.y
212212
#Nx.Tensor<
213-
s64[2]
213+
s32[2]
214214
[0, 1]
215215
>
216216
@@ -249,7 +249,7 @@ defmodule Bumblebee.Utils.Nx do
249249
iex> idx = Nx.tensor([[1, 0], [1, 1]])
250250
iex> Bumblebee.Utils.Nx.batched_take(t, idx)
251251
#Nx.Tensor<
252-
s64[2][2][2]
252+
s32[2][2][2]
253253
[
254254
[
255255
[2, 2],
@@ -348,7 +348,7 @@ defmodule Bumblebee.Utils.Nx do
348348
iex> x = Nx.tensor([[1, 2], [3, 4]])
349349
iex> Bumblebee.Utils.Nx.repeat_interleave(x, 2)
350350
#Nx.Tensor<
351-
s64[4][2]
351+
s32[4][2]
352352
[
353353
[1, 2],
354354
[1, 2],
@@ -387,7 +387,7 @@ defmodule Bumblebee.Utils.Nx do
387387
iex> x = Nx.tensor([[1, 1], [2, 2], [3, 3], [4, 4]])
388388
iex> Bumblebee.Utils.Nx.chunked_take(x, 2, Nx.tensor([1, 0]))
389389
#Nx.Tensor<
390-
s64[2][2]
390+
s32[2][2]
391391
[
392392
[2, 2],
393393
[3, 3]
@@ -427,7 +427,7 @@ defmodule Bumblebee.Utils.Nx do
427427
iex> x = Nx.iota({3, 3})
428428
iex> Bumblebee.Utils.Nx.roll(x, shifts: [1], axes: [0])
429429
#Nx.Tensor<
430-
s64[3][3]
430+
s32[3][3]
431431
[
432432
[6, 7, 8],
433433
[0, 1, 2],
@@ -438,7 +438,7 @@ defmodule Bumblebee.Utils.Nx do
438438
iex> x = Nx.iota({3, 3})
439439
iex> Bumblebee.Utils.Nx.roll(x, shifts: [-1], axes: [0])
440440
#Nx.Tensor<
441-
s64[3][3]
441+
s32[3][3]
442442
[
443443
[3, 4, 5],
444444
[6, 7, 8],
@@ -449,7 +449,7 @@ defmodule Bumblebee.Utils.Nx do
449449
iex> x = Nx.iota({3, 3})
450450
iex> Bumblebee.Utils.Nx.roll(x, shifts: [1, 2], axes: [0, 1])
451451
#Nx.Tensor<
452-
s64[3][3]
452+
s32[3][3]
453453
[
454454
[7, 8, 6],
455455
[1, 2, 0],

0 commit comments

Comments
 (0)