@@ -75,22 +75,22 @@ defmodule Bumblebee.Utils.Nx do
7575 iex> [first, second] = Bumblebee.Utils.Nx.batch_to_list(outputs)
7676 iex> first.x
7777 #Nx.Tensor<
78- s64 [2]
78+ s32 [2]
7979 [0, 0]
8080 >
8181 iex> second.x
8282 #Nx.Tensor<
83- s64 [2]
83+ s32 [2]
8484 [1, 1]
8585 >
8686 iex> first.y
8787 #Nx.Tensor<
88- s64
88+ s32
8989 0
9090 >
9191 iex> second.y
9292 #Nx.Tensor<
93- s64
93+ s32
9494 1
9595 >
9696
@@ -122,7 +122,7 @@ defmodule Bumblebee.Utils.Nx do
122122 iex> result = Bumblebee.Utils.Nx.composite_concatenate(left, right)
123123 iex> result.x
124124 #Nx.Tensor<
125- s64 [4][2]
125+ s32 [4][2]
126126 [
127127 [0, 0],
128128 [1, 1],
@@ -132,7 +132,7 @@ defmodule Bumblebee.Utils.Nx do
132132 >
133133 iex> result.y
134134 #Nx.Tensor<
135- s64 [4]
135+ s32 [4]
136136 [0, 1, 2, 3]
137137 >
138138
@@ -164,7 +164,7 @@ defmodule Bumblebee.Utils.Nx do
164164 iex> result = Bumblebee.Utils.Nx.composite_unflatten_batch(output, 2)
165165 iex> result.x
166166 #Nx.Tensor<
167- s64 [2][1][2]
167+ s32 [2][1][2]
168168 [
169169 [
170170 [0, 0]
@@ -176,7 +176,7 @@ defmodule Bumblebee.Utils.Nx do
176176 >
177177 iex> result.y
178178 #Nx.Tensor<
179- s64 [2][1]
179+ s32 [2][1]
180180 [
181181 [0],
182182 [1]
@@ -205,12 +205,12 @@ defmodule Bumblebee.Utils.Nx do
205205 iex> result = Bumblebee.Utils.Nx.composite_flatten_batch(output)
206206 iex> result.x
207207 #Nx.Tensor<
208- s64 [4]
208+ s32 [4]
209209 [0, 0, 1, 1]
210210 >
211211 iex> result.y
212212 #Nx.Tensor<
213- s64 [2]
213+ s32 [2]
214214 [0, 1]
215215 >
216216
@@ -249,7 +249,7 @@ defmodule Bumblebee.Utils.Nx do
249249 iex> idx = Nx.tensor([[1, 0], [1, 1]])
250250 iex> Bumblebee.Utils.Nx.batched_take(t, idx)
251251 #Nx.Tensor<
252- s64 [2][2][2]
252+ s32 [2][2][2]
253253 [
254254 [
255255 [2, 2],
@@ -348,7 +348,7 @@ defmodule Bumblebee.Utils.Nx do
348348 iex> x = Nx.tensor([[1, 2], [3, 4]])
349349 iex> Bumblebee.Utils.Nx.repeat_interleave(x, 2)
350350 #Nx.Tensor<
351- s64 [4][2]
351+ s32 [4][2]
352352 [
353353 [1, 2],
354354 [1, 2],
@@ -387,7 +387,7 @@ defmodule Bumblebee.Utils.Nx do
387387 iex> x = Nx.tensor([[1, 1], [2, 2], [3, 3], [4, 4]])
388388 iex> Bumblebee.Utils.Nx.chunked_take(x, 2, Nx.tensor([1, 0]))
389389 #Nx.Tensor<
390- s64 [2][2]
390+ s32 [2][2]
391391 [
392392 [2, 2],
393393 [3, 3]
@@ -427,7 +427,7 @@ defmodule Bumblebee.Utils.Nx do
427427 iex> x = Nx.iota({3, 3})
428428 iex> Bumblebee.Utils.Nx.roll(x, shifts: [1], axes: [0])
429429 #Nx.Tensor<
430- s64 [3][3]
430+ s32 [3][3]
431431 [
432432 [6, 7, 8],
433433 [0, 1, 2],
@@ -438,7 +438,7 @@ defmodule Bumblebee.Utils.Nx do
438438 iex> x = Nx.iota({3, 3})
439439 iex> Bumblebee.Utils.Nx.roll(x, shifts: [-1], axes: [0])
440440 #Nx.Tensor<
441- s64 [3][3]
441+ s32 [3][3]
442442 [
443443 [3, 4, 5],
444444 [6, 7, 8],
@@ -449,7 +449,7 @@ defmodule Bumblebee.Utils.Nx do
449449 iex> x = Nx.iota({3, 3})
450450 iex> Bumblebee.Utils.Nx.roll(x, shifts: [1, 2], axes: [0, 1])
451451 #Nx.Tensor<
452- s64 [3][3]
452+ s32 [3][3]
453453 [
454454 [7, 8, 6],
455455 [1, 2, 0],
0 commit comments