@@ -1210,7 +1210,7 @@ defmodule EXLA.Defn do
1210
1210
Value . dynamic_update_slice ( tensor , slice , start_indices )
1211
1211
end
1212
1212
1213
- defp to_operator ( :take , [ % mod { } = tensor , indices , axis ] , _ans , _state ) do
1213
+ defp to_operator ( :take , [ % Value { } = tensor , indices , axis ] , _ans , _state ) do
1214
1214
tensor_rank = tensor |> op_shape ( ) |> tuple_size ( )
1215
1215
indices_rank = indices |> op_shape ( ) |> tuple_size ( )
1216
1216
result_rank = tensor_rank - 1 + indices_rank
@@ -1221,7 +1221,7 @@ defmodule EXLA.Defn do
1221
1221
collapsed_slice_dims = [ axis ]
1222
1222
start_index_map = [ axis ]
1223
1223
1224
- mod . gather (
1224
+ Value . gather (
1225
1225
tensor ,
1226
1226
indices ,
1227
1227
index_vector_dim ,
@@ -1232,7 +1232,7 @@ defmodule EXLA.Defn do
1232
1232
)
1233
1233
end
1234
1234
1235
- defp to_operator ( :take_along_axis , [ % mod { } = tensor , indices , axis ] , _ans , state ) do
1235
+ defp to_operator ( :take_along_axis , [ % Value { } = tensor , indices , axis ] , _ans , state ) do
1236
1236
indices_shape = op_shape ( indices )
1237
1237
indices_rank = tuple_size ( indices_shape )
1238
1238
@@ -1244,22 +1244,22 @@ defmodule EXLA.Defn do
1244
1244
collapsed_slice_dims = Enum . to_list ( axes_range )
1245
1245
start_index_map = Enum . to_list ( axes_range )
1246
1246
1247
- indices_exla_shape = mod . get_shape ( indices )
1247
+ indices_exla_shape = Value . get_shape ( indices )
1248
1248
1249
1249
iotas =
1250
1250
Enum . map ( axes_range , fn axis ->
1251
- mod . iota ( state . builder , indices_exla_shape , axis )
1251
+ Value . iota ( state . builder , indices_exla_shape , axis )
1252
1252
end )
1253
1253
1254
1254
new_axis_shape = Tuple . append ( indices_shape , 1 )
1255
1255
1256
1256
indices =
1257
1257
iotas
1258
1258
|> List . replace_at ( axis , indices )
1259
- |> Enum . map ( & mod . reshape ( & 1 , new_axis_shape ) )
1260
- |> mod . concatenate ( indices_rank )
1259
+ |> Enum . map ( & Value . reshape ( & 1 , new_axis_shape ) )
1260
+ |> Value . concatenate ( indices_rank )
1261
1261
1262
- mod . gather (
1262
+ Value . gather (
1263
1263
tensor ,
1264
1264
indices ,
1265
1265
index_vector_dim ,
@@ -1270,7 +1270,7 @@ defmodule EXLA.Defn do
1270
1270
)
1271
1271
end
1272
1272
1273
- defp to_operator ( :gather , [ % mod { } = tensor , indices , opts ] , _ans , _state ) do
1273
+ defp to_operator ( :gather , [ % Value { } = tensor , indices , opts ] , _ans , _state ) do
1274
1274
axes = Keyword . fetch! ( opts , :axes )
1275
1275
tensor_shape = op_shape ( tensor )
1276
1276
tensor_rank = tuple_size ( tensor_shape )
@@ -1284,7 +1284,7 @@ defmodule EXLA.Defn do
1284
1284
1285
1285
batch_size = tensor_rank - length ( axes )
1286
1286
offset_dims = count_up ( batch_size , batch_size )
1287
- mod . gather ( tensor , indices , index_vector_dim , slice_sizes , offset_dims , axes , axes )
1287
+ Value . gather ( tensor , indices , index_vector_dim , slice_sizes , offset_dims , axes , axes )
1288
1288
end
1289
1289
1290
1290
defp to_operator ( :reverse , [ % Value { } = tensor , axes ] , _ans , _state ) do
@@ -1339,7 +1339,7 @@ defmodule EXLA.Defn do
1339
1339
EXLA.Lib . argsort ( state . builder , tensor , dimension , stable , comp , ans . type )
1340
1340
end
1341
1341
1342
- defp fft ( exla_op , [ % mod { } = tensor , opts ] , % { type: type } , state ) do
1342
+ defp fft ( exla_op , [ % Value { } = tensor , opts ] , % { type: type } , state ) do
1343
1343
n = opts [ :length ]
1344
1344
axis = opts [ :axis ]
1345
1345
output_type = Nx.Type . to_complex ( type )
@@ -1362,15 +1362,15 @@ defmodule EXLA.Defn do
1362
1362
|> List . to_tuple ( )
1363
1363
1364
1364
tensor
1365
- |> mod . transpose ( permutation )
1365
+ |> Value . transpose ( permutation )
1366
1366
|> exla_op . ( [ n ] )
1367
- |> mod . transpose ( permutation )
1367
+ |> Value . transpose ( permutation )
1368
1368
else
1369
1369
exla_op . ( tensor , [ n ] )
1370
1370
end
1371
1371
end
1372
1372
1373
- defp fft2 ( exla_op , [ % mod { } = tensor , opts ] , % { type: type } , state ) do
1373
+ defp fft2 ( exla_op , [ % Value { } = tensor , opts ] , % { type: type } , state ) do
1374
1374
[ l1 , l2 ] = lengths = opts [ :lengths ]
1375
1375
[ ax1 , ax2 ] = axes = opts [ :axes ]
1376
1376
output_type = Nx.Type . to_complex ( type )
@@ -1399,9 +1399,9 @@ defmodule EXLA.Defn do
1399
1399
|> List . to_tuple ( )
1400
1400
1401
1401
tensor
1402
- |> mod . transpose ( permutation )
1402
+ |> Value . transpose ( permutation )
1403
1403
|> exla_op . ( lengths )
1404
- |> mod . transpose ( permutation )
1404
+ |> Value . transpose ( permutation )
1405
1405
else
1406
1406
exla_op . ( tensor , lengths )
1407
1407
end
0 commit comments