Skip to content

Commit 510e689

Browse files
committed
Deprecate ~V and ~M in favor of ~VEC and ~MAT
1 parent 7c36e06 commit 510e689

File tree

13 files changed

+319
-297
lines changed

13 files changed

+319
-297
lines changed

exla/test/exla/backend_test.exs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
defmodule EXLA.BackendTest do
22
use EXLA.Case, async: true
33

4-
import Nx, only: [sigil_V: 2]
4+
import Nx, only: [sigil_VEC: 2]
55

66
setup do
77
Nx.default_backend(EXLA.Backend)
@@ -192,7 +192,7 @@ defmodule EXLA.BackendTest do
192192
end
193193

194194
test "conjugate" do
195-
assert inspect(Nx.conjugate(~V[1 2-0i 3+0i 0-i 0-2i])) =~
195+
assert inspect(Nx.conjugate(~VEC[1 2-0i 3+0i 0-i 0-2i])) =~
196196
"1.0-0.0i, 2.0+0.0i, 3.0-0.0i, 0.0+1.0i, 0.0+2.0i"
197197
end
198198
end

exla/test/exla/defn/expr_test.exs

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -815,17 +815,17 @@ defmodule EXLA.Defn.ExprTest do
815815
test "fft" do
816816
assert_all_close(
817817
fft(Nx.tensor([1, 1, 0, 0]), length: 5),
818-
~V[2.0+0.0i 1.3090-0.9511i 0.1909-0.5877i 0.1909+0.5877i 1.3090+0.9510i]
818+
~VEC[2.0+0.0i 1.3090-0.9511i 0.1909-0.5877i 0.1909+0.5877i 1.3090+0.9510i]
819819
)
820820

821821
assert_all_close(
822822
fft(Nx.tensor([1, 1, 0, 0, 2, 3]), length: 4),
823-
~V[2.0+0.0i 1.0-1.0i 0.0+0.0i 1.0+1.0i]
823+
~VEC[2.0+0.0i 1.0-1.0i 0.0+0.0i 1.0+1.0i]
824824
)
825825

826826
assert_all_close(
827827
fft(Nx.tensor([1, 1, 0]), length: :power_of_two),
828-
~V[2.0+0.0i 1.0-1.0i 0.0+0.0i 1.0+1.0i]
828+
~VEC[2.0+0.0i 1.0-1.0i 0.0+0.0i 1.0+1.0i]
829829
)
830830
end
831831

@@ -847,12 +847,12 @@ defmodule EXLA.Defn.ExprTest do
847847
length: :power_of_two
848848
),
849849
Nx.stack([
850-
~M[
850+
~MAT[
851851
2 1.0-1.0i 0 1.0+1.0i
852852
1 1 1 1
853853
1 -1i -1 1i
854854
],
855-
~M[
855+
~MAT[
856856
1 -1i -1 1i
857857
1 1 1 1
858858
2 1.0-1.0i 0 1.0+1.0i
@@ -877,12 +877,12 @@ defmodule EXLA.Defn.ExprTest do
877877
length: 4
878878
),
879879
Nx.stack([
880-
~M[
880+
~MAT[
881881
2 1.0-1.0i 0 1.0+1.0i
882882
1 1 1 1
883883
1 -1i -1 1i
884884
],
885-
~M[
885+
~MAT[
886886
1 -1i -1 1i
887887
1 1 1 1
888888
2 1.0-1.0i 0 1.0+1.0i
@@ -907,12 +907,12 @@ defmodule EXLA.Defn.ExprTest do
907907
length: 4
908908
),
909909
Nx.stack([
910-
~M[
910+
~MAT[
911911
2 1.0-1.0i 0 1.0+1.0i
912912
1 1 1 1
913913
1 -1i -1 1i
914914
],
915-
~M[
915+
~MAT[
916916
1 -1i -1 1i
917917
1 1 1 1
918918
2 1.0-1.0i 0 1.0+1.0i
@@ -923,19 +923,19 @@ defmodule EXLA.Defn.ExprTest do
923923

924924
test "ifft" do
925925
assert_all_close(
926-
ifft(~V[5 5 5 5 5],
926+
ifft(~VEC[5 5 5 5 5],
927927
length: 5
928928
),
929929
Nx.tensor([5, 0, 0, 0, 0])
930930
)
931931

932932
assert_all_close(
933-
ifft(~V[2.0+0.0i 1.0-1.0i 0.0+0.0i 1.0+1.0i 5 6], length: 4),
933+
ifft(~VEC[2.0+0.0i 1.0-1.0i 0.0+0.0i 1.0+1.0i 5 6], length: 4),
934934
Nx.tensor([1, 1, 0, 0])
935935
)
936936

937937
assert_all_close(
938-
ifft(~V[2 0 0], length: :power_of_two),
938+
ifft(~VEC[2 0 0], length: :power_of_two),
939939
Nx.tensor([0.5, 0.5, 0.5, 0.5])
940940
)
941941
end
@@ -944,12 +944,12 @@ defmodule EXLA.Defn.ExprTest do
944944
assert_all_close(
945945
ifft(
946946
Nx.stack([
947-
~M[
947+
~MAT[
948948
2 1.0-1.0i 0 1.0+1.0i
949949
1 1 1 1
950950
1 -1i -1 1i
951951
],
952-
~M[
952+
~MAT[
953953
1 -1i -1 1i
954954
1 1 1 1
955955
2 1.0-1.0i 0 1.0+1.0i
@@ -988,12 +988,12 @@ defmodule EXLA.Defn.ExprTest do
988988
length: 4
989989
),
990990
Nx.stack([
991-
~M[
991+
~MAT[
992992
2 1.0+1.0i 0 1.0-1.0i
993993
1 1 1 1
994994
1 1i -1 -1i
995995
],
996-
~M[
996+
~MAT[
997997
1 1i -1 -1i
998998
1 1 1 1
999999
2 1.0+1.0i 0 1.0-1.0i
@@ -1018,12 +1018,12 @@ defmodule EXLA.Defn.ExprTest do
10181018
length: 4
10191019
),
10201020
Nx.stack([
1021-
~M[
1021+
~MAT[
10221022
2 1.0+1.0i 0 1.0-1.0i
10231023
1 1 1 1
10241024
1 1i -1 -1i
10251025
],
1026-
~M[
1026+
~MAT[
10271027
1 1i -1 -1i
10281028
1 1 1 1
10291029
2 1.0+1.0i 0 1.0-1.0i

exla/test/exla/defn/vectorize_test.exs

Lines changed: 43 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -182,9 +182,9 @@ defmodule EXLA.Defn.VectorizeTest do
182182

183183
test "simple if" do
184184
# this tests the case where we have a single vectorized predicate
185-
pred = Nx.vectorize(~V[0 1 0], :pred)
185+
pred = Nx.vectorize(~VEC[0 1 0], :pred)
186186

187-
assert_equal(vectorized_if(pred, 1, 2, pid: self()), Nx.vectorize(~V[2 1 2], :pred))
187+
assert_equal(vectorized_if(pred, 1, 2, pid: self()), Nx.vectorize(~VEC[2 1 2], :pred))
188188

189189
assert_received {:vectorization_test, t, clause: "if"}
190190
assert_equal(t, Nx.tensor(1))
@@ -195,12 +195,12 @@ defmodule EXLA.Defn.VectorizeTest do
195195

196196
test "simple cond" do
197197
# this tests the case where we have a two vectorized predicates
198-
pred1 = Nx.vectorize(~V[1 0 0], :pred)
199-
pred2 = Nx.vectorize(~V[0 0 0], :pred)
198+
pred1 = Nx.vectorize(~VEC[1 0 0], :pred)
199+
pred2 = Nx.vectorize(~VEC[0 0 0], :pred)
200200

201201
assert_equal(
202202
vectorized_cond(pred1, 1, pred2, 2, 3, pid: self()),
203-
Nx.vectorize(~V[1 3 3], :pred)
203+
Nx.vectorize(~VEC[1 3 3], :pred)
204204
)
205205

206206
assert_received {:vectorization_test, t, clause: "clause_1"}
@@ -211,20 +211,20 @@ defmodule EXLA.Defn.VectorizeTest do
211211
end
212212

213213
test "if with container result" do
214-
pred1 = Nx.vectorize(~V[2 0 0], :pred)
214+
pred1 = Nx.vectorize(~VEC[2 0 0], :pred)
215215

216216
result =
217217
vectorized_if(
218218
pred1,
219219
{1, 2, 3},
220-
{7, 8, Nx.vectorize(~V[9 10 11], :x)},
220+
{7, 8, Nx.vectorize(~VEC[9 10 11], :x)},
221221
pid: self()
222222
)
223223

224224
assert_equal(result, {
225-
Nx.vectorize(~V[1 7 7], :pred),
226-
Nx.vectorize(~V[2 8 8], :pred),
227-
Nx.vectorize(~M[
225+
Nx.vectorize(~VEC[1 7 7], :pred),
226+
Nx.vectorize(~VEC[2 8 8], :pred),
227+
Nx.vectorize(~MAT[
228228
3 3 3
229229
9 10 11
230230
9 10 11
@@ -248,8 +248,8 @@ defmodule EXLA.Defn.VectorizeTest do
248248
end
249249

250250
test "only executes selected branches" do
251-
t = Nx.vectorize(~V[1], :pred)
252-
f = Nx.vectorize(~V[0], :pred)
251+
t = Nx.vectorize(~VEC[1], :pred)
252+
f = Nx.vectorize(~VEC[0], :pred)
253253

254254
assert = fn res, val, clause ->
255255
t = Nx.tensor(val)
@@ -267,74 +267,74 @@ defmodule EXLA.Defn.VectorizeTest do
267267

268268
test "1 vectorized pred in the beginning" do
269269
assert_equal(
270-
cond4(Nx.vectorize(~V[0 1], :pred), 10, 0, 20, 0, 30, 40),
271-
Nx.vectorize(~V[40 10], :pred)
270+
cond4(Nx.vectorize(~VEC[0 1], :pred), 10, 0, 20, 0, 30, 40),
271+
Nx.vectorize(~VEC[40 10], :pred)
272272
)
273273

274274
assert_equal(
275-
cond4(Nx.vectorize(~V[0 0], :pred), 10, 1, 20, 0, 30, 40),
276-
Nx.vectorize(~V[20 20], :pred)
275+
cond4(Nx.vectorize(~VEC[0 0], :pred), 10, 1, 20, 0, 30, 40),
276+
Nx.vectorize(~VEC[20 20], :pred)
277277
)
278278

279279
assert_equal(
280-
cond4(Nx.vectorize(~V[0 0], :pred), 10, 0, 20, 1, 30, 40),
281-
Nx.vectorize(~V[30 30], :pred)
280+
cond4(Nx.vectorize(~VEC[0 0], :pred), 10, 0, 20, 1, 30, 40),
281+
Nx.vectorize(~VEC[30 30], :pred)
282282
)
283283

284284
assert_equal(
285-
cond4(Nx.vectorize(~V[0 0], :pred), 10, 0, 20, 0, 30, 40),
286-
Nx.vectorize(~V[40 40], :pred)
285+
cond4(Nx.vectorize(~VEC[0 0], :pred), 10, 0, 20, 0, 30, 40),
286+
Nx.vectorize(~VEC[40 40], :pred)
287287
)
288288
end
289289

290290
test "1 vectorized pred in the second but not last position" do
291291
assert_equal(
292-
cond4(0, 10, Nx.vectorize(~V[0 1], :pred), 20, 0, 30, 40),
293-
Nx.vectorize(~V[40 20], :pred)
292+
cond4(0, 10, Nx.vectorize(~VEC[0 1], :pred), 20, 0, 30, 40),
293+
Nx.vectorize(~VEC[40 20], :pred)
294294
)
295295

296296
assert_equal(
297-
cond4(1, 10, Nx.vectorize(~V[0 1], :pred), 20, 0, 30, 40),
298-
Nx.vectorize(~V[10 10], :pred)
297+
cond4(1, 10, Nx.vectorize(~VEC[0 1], :pred), 20, 0, 30, 40),
298+
Nx.vectorize(~VEC[10 10], :pred)
299299
)
300300

301301
assert_equal(
302-
cond4(0, 10, Nx.vectorize(~V[0 0], :pred), 20, 1, 30, 40),
303-
Nx.vectorize(~V[30 30], :pred)
302+
cond4(0, 10, Nx.vectorize(~VEC[0 0], :pred), 20, 1, 30, 40),
303+
Nx.vectorize(~VEC[30 30], :pred)
304304
)
305305

306306
assert_equal(
307-
cond4(0, 10, Nx.vectorize(~V[0 0], :pred), 20, 0, 30, 40),
308-
Nx.vectorize(~V[40 40], :pred)
307+
cond4(0, 10, Nx.vectorize(~VEC[0 0], :pred), 20, 0, 30, 40),
308+
Nx.vectorize(~VEC[40 40], :pred)
309309
)
310310
end
311311

312312
test "1 vectorized pred in the last position" do
313313
assert_equal(
314-
cond4(0, 10, 0, 20, Nx.vectorize(~V[0 1], :pred), 30, 40),
315-
Nx.vectorize(~V[40 30], :pred)
314+
cond4(0, 10, 0, 20, Nx.vectorize(~VEC[0 1], :pred), 30, 40),
315+
Nx.vectorize(~VEC[40 30], :pred)
316316
)
317317

318318
assert_equal(
319-
cond4(1, 10, 0, 20, Nx.vectorize(~V[0 1], :pred), 30, 40),
320-
Nx.vectorize(~V[10 10], :pred)
319+
cond4(1, 10, 0, 20, Nx.vectorize(~VEC[0 1], :pred), 30, 40),
320+
Nx.vectorize(~VEC[10 10], :pred)
321321
)
322322

323323
assert_equal(
324-
cond4(0, 10, 1, 20, Nx.vectorize(~V[0 1], :pred), 30, 40),
325-
Nx.vectorize(~V[20 20], :pred)
324+
cond4(0, 10, 1, 20, Nx.vectorize(~VEC[0 1], :pred), 30, 40),
325+
Nx.vectorize(~VEC[20 20], :pred)
326326
)
327327

328328
assert_equal(
329-
cond4(0, 10, 0, 20, Nx.vectorize(~V[0 0], :pred), 30, 40),
330-
Nx.vectorize(~V[40 40], :pred)
329+
cond4(0, 10, 0, 20, Nx.vectorize(~VEC[0 0], :pred), 30, 40),
330+
Nx.vectorize(~VEC[40 40], :pred)
331331
)
332332
end
333333

334334
test "2 vectorized preds with different axes" do
335335
assert_equal(
336-
cond4(Nx.vectorize(~V[0 1 0], :pred1), 10, Nx.vectorize(~V[1 0], :pred2), 20, 0, 30, 40),
337-
Nx.vectorize(~M[
336+
cond4(Nx.vectorize(~VEC[0 1 0], :pred1), 10, Nx.vectorize(~VEC[1 0], :pred2), 20, 0, 30, 40),
337+
Nx.vectorize(~MAT[
338338
20 40
339339
10 10
340340
20 40
@@ -345,15 +345,15 @@ defmodule EXLA.Defn.VectorizeTest do
345345
test "2 vectorized preds with different axes + clauses that match either" do
346346
assert_equal(
347347
cond4(
348-
Nx.vectorize(~V[0 1 0], :pred1),
349-
Nx.vectorize(~V[10 100], :pred2),
350-
Nx.vectorize(~V[1 0], :pred2),
351-
Nx.vectorize(~V[20 200 2000], :pred1),
348+
Nx.vectorize(~VEC[0 1 0], :pred1),
349+
Nx.vectorize(~VEC[10 100], :pred2),
350+
Nx.vectorize(~VEC[1 0], :pred2),
351+
Nx.vectorize(~VEC[20 200 2000], :pred1),
352352
0,
353353
30,
354354
40
355355
),
356-
Nx.vectorize(~M[
356+
Nx.vectorize(~MAT[
357357
20 40
358358
10 100
359359
2000 40

nx/guides/advanced/aggregation.livemd

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ max_y = Nx.reduce_max(m, axes: [:y])
7575
Let's consider another example with [Nx.weighted_mean](https://hexdocs.pm/nx/Nx.html#weighted_mean/3). It supports full-tensor and per axis operations. We display how to compute the _weighted mean aggregate_ of a matrix with the example below of a 2D tensor of shape `{2,2}` labeled `m`:
7676

7777
```elixir
78-
m = ~M[
78+
m = ~MAT[
7979
1 2
8080
3 4
8181
]
@@ -96,7 +96,7 @@ m = ~M[
9696
First, we'll compute the full-tensor aggregation. The calculations are developed below. We calculate an "array product" (aka [Hadamard product](<https://en.wikipedia.org/wiki/Hadamard_product_(matrices)#:~:text=In%20mathematics%2C%20the%20Hadamard%20product,elements%20i%2C%20j%20of%20the>), an element-wise product) of our tensor with the tensor of weights, then sum all the elements and divide by the sum of the weights.
9797

9898
```elixir
99-
w = ~M[
99+
w = ~MAT[
100100
10 20
101101
30 40
102102
]
@@ -121,7 +121,7 @@ man_w_avg = (1 * 10 + 2 * 20 + 3 * 30 + 4 * 40) / (10 + 20 + 30 + 40)
121121
The weighted mean can be computed _per axis_. Let's compute it along the _first_ axis (`axes: [0]`): you calculate "by column", so you aggregate/reduce along the first axis:
122122

123123
```elixir
124-
w = ~M[
124+
w = ~MAT[
125125
10 20
126126
30 40
127127
]
@@ -148,7 +148,7 @@ man_w_avg_x = [(1 * 10 + 3 * 30) / (10 + 30), (2 * 20 + 4 * 40) / (20 + 40)]
148148
We calculate weighted mean of a square matrix along the _second_ axis (`axes: [1]`): you calculate per row, so you aggregate/reduce along the second axis.
149149

150150
```elixir
151-
w = ~M[
151+
w = ~MAT[
152152
10 20
153153
30 40
154154
]
@@ -816,7 +816,7 @@ Nx.argmin(t, axis: 3)
816816
You have the `:tie_break` option to decide how to operate with you have several occurences of the result. It defaults to `tie_break: :low`.
817817

818818
```elixir
819-
t4 = ~V[2 0 0 0 1]
819+
t4 = ~VEC[2 0 0 0 1]
820820

821821
%{
822822
argmin_with_default: Nx.argmin(t4) |> Nx.to_number(),

0 commit comments

Comments
 (0)