
Commit 1a53072

Simplify TensorType.shape format in str output
1 parent b3a6911 commit 1a53072

File tree

11 files changed (+79, -60 lines)


aesara/tensor/type.py

Lines changed: 12 additions & 1 deletion
@@ -386,7 +386,18 @@ def __str__(self):
         if self.name:
             return self.name
         else:
-            return f"TensorType({self.dtype}, {self.shape})"
+
+            def shape_str(s):
+                if s is None:
+                    return "?"
+                else:
+                    return str(s)
+
+            formatted_shape = ", ".join([shape_str(s) for s in self.shape])
+            if len(self.shape) == 1:
+                formatted_shape += ","
+
+            return f"TensorType({self.dtype}, ({formatted_shape}))"
 
     def __repr__(self):
         return str(self)
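To see the effect of the new format, one can print a few unnamed types directly. This is a minimal sketch assuming an Aesara build that includes this commit; the dtypes and shapes are illustrative:

    from aesara.tensor.type import TensorType

    # Unknown (None) dimensions now render as "?".
    print(TensorType("float64", (2, None)))  # TensorType(float64, (2, ?))

    # Length-1 shapes keep the trailing comma, mirroring tuple syntax.
    print(TensorType("float64", (None,)))  # TensorType(float64, (?,))
    print(TensorType("int64", ()))  # TensorType(int64, ())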

doc/extending/graphstructures.rst

Lines changed: 1 addition & 1 deletion
@@ -217,7 +217,7 @@ For example, :ref:`aesara.tensor.irow <libdoc_tensor_creation>` is an instance o
 
 >>> from aesara.tensor import irow
 >>> irow()
-<TensorType(int32, (1, None))>
+<TensorType(int32, (1, ?))>
 
 As the string print-out shows, `irow` specifies the following information about
 the :class:`Variable`\s it constructs:

doc/extending/type.rst

Lines changed: 2 additions & 2 deletions
@@ -90,7 +90,7 @@ For example, let's say we have two :class:`Variable`\s with the following
 >>> from aesara.tensor.type import TensorType
 >>> v1 = TensorType("float64", (2, None))()
 >>> v1.type
-TensorType(float64, (2, None))
+TensorType(float64, (2, ?))
 >>> v2 = TensorType("float64", (2, 1))()
 >>> v2.type
 TensorType(float64, (2, 1))
@@ -145,7 +145,7 @@ SpecifyShape.0
 >>> import aesara
 >>> aesara.dprint(v3, print_type=True)
 SpecifyShape [id A] <TensorType(float64, (2, 1))>
- |<TensorType(float64, (2, None))> [id B] <TensorType(float64, (2, None))>
+ |<TensorType(float64, (2, ?))> [id B] <TensorType(float64, (2, ?))>
  |TensorConstant{2} [id C] <TensorType(int8, ())>
  |TensorConstant{1} [id D] <TensorType(int8, ())>

doc/library/scan.rst

Lines changed: 1 addition & 1 deletion
@@ -406,7 +406,7 @@ Using the original Gibbs sampling example, with ``strict=True`` added to the
 Traceback (most recent call last):
 ...
 MissingInputError: An input of the graph, used to compute
-DimShuffle{1,0}(<TensorType(float64, (None, None))>), was not provided and
+DimShuffle{1,0}(<TensorType(float64, (?, ?))>), was not provided and
 not given a value.Use the Aesara flag exception_verbosity='high',for
 more information on this error.

doc/tutorial/debug_faq.rst

Lines changed: 15 additions & 15 deletions
@@ -44,8 +44,8 @@ Running the code above we see:
 Traceback (most recent call last):
 ...
 ValueError: Input dimension mismatch. (input[0].shape[0] = 3, input[1].shape[0] = 2)
-Apply node that caused the error: Elemwise{add,no_inplace}(<TensorType(float64, (None,))>, <TensorType(float64, (None,))>, <TensorType(float64, (None,))>)
-Inputs types: [TensorType(float64, (None,)), TensorType(float64, (None,)), TensorType(float64, (None,))]
+Apply node that caused the error: Elemwise{add,no_inplace}(<TensorType(float64, (?,))>, <TensorType(float64, (?,))>, <TensorType(float64, (?,))>)
+Inputs types: [TensorType(float64, (?,)), TensorType(float64, (?,)), TensorType(float64, (?,))]
 Inputs shapes: [(3,), (2,), (2,)]
 Inputs strides: [(8,), (8,), (8,)]
 Inputs scalar values: ['not scalar', 'not scalar', 'not scalar']
@@ -73,11 +73,11 @@ message becomes :
 z = z + y
 
 Debugprint of the apply node:
-Elemwise{add,no_inplace} [id A] <TensorType(float64, (None,))> ''
- |Elemwise{add,no_inplace} [id B] <TensorType(float64, (None,))> ''
- | |<TensorType(float64, (None,))> [id C] <TensorType(float64, (None,))>
- | |<TensorType(float64, (None,))> [id C] <TensorType(float64, (None,))>
- |<TensorType(float64, (None,))> [id D] <TensorType(float64, (None,))>
+Elemwise{add,no_inplace} [id A] <TensorType(float64, (?,))> ''
+ |Elemwise{add,no_inplace} [id B] <TensorType(float64, (?,))> ''
+ | |<TensorType(float64, (?,))> [id C] <TensorType(float64, (?,))>
+ | |<TensorType(float64, (?,))> [id C] <TensorType(float64, (?,))>
+ |<TensorType(float64, (?,))> [id D] <TensorType(float64, (?,))>
 
 We can here see that the error can be traced back to the line ``z = z + y``.
 For this example, using ``optimizer=fast_compile`` worked. If it did not,
@@ -145,18 +145,18 @@ Running the above code generates the following error message:
 outputs = self.vm()
 ValueError: Shape mismatch: x has 10 cols (and 5 rows) but y has 20 rows (and 10 cols)
 Apply node that caused the error: Dot22(x, DimShuffle{1,0}.0)
-Inputs types: [TensorType(float64, (None, None)), TensorType(float64, (None, None))]
+Inputs types: [TensorType(float64, (?, ?)), TensorType(float64, (?, ?))]
 Inputs shapes: [(5, 10), (20, 10)]
 Inputs strides: [(80, 8), (8, 160)]
 Inputs scalar values: ['not scalar', 'not scalar']
 
 Debugprint of the apply node:
-Dot22 [id A] <TensorType(float64, (None, None))> ''
- |x [id B] <TensorType(float64, (None, None))>
- |DimShuffle{1,0} [id C] <TensorType(float64, (None, None))> ''
- |Flatten{2} [id D] <TensorType(float64, (None, None))> ''
- |DimShuffle{2,0,1} [id E] <TensorType(float64, (None, None, None))> ''
- |W1 [id F] <TensorType(float64, (None, None, None))>
+Dot22 [id A] <TensorType(float64, (?, ?))> ''
+ |x [id B] <TensorType(float64, (?, ?))>
+ |DimShuffle{1,0} [id C] <TensorType(float64, (?, ?))> ''
+ |Flatten{2} [id D] <TensorType(float64, (?, ?))> ''
+ |DimShuffle{2,0,1} [id E] <TensorType(float64, (?, ?, ?))> ''
+ |W1 [id F] <TensorType(float64, (?, ?, ?))>
 
 HINT: Re-running with most Aesara optimization disabled could give you a back-traces when this node was created. This can be done with by setting the Aesara flags 'optimizer=fast_compile'. If that does not work, Aesara optimization can be disabled with 'optimizer=None'.
 
@@ -483,7 +483,7 @@ Consider this example script (``ex.py``):
 ValueError: Input dimension mismatch. (input[0].shape[0] = 3, input[1].shape[0] = 5)
 Apply node that caused the error: Elemwise{mul,no_inplace}(a, b)
 Toposort index: 0
-Inputs types: [TensorType(float64, (None, None)), TensorType(float64, (None, None))]
+Inputs types: [TensorType(float64, (?, ?)), TensorType(float64, (?, ?))]
 Inputs shapes: [(3, 4), (5, 5)]
 Inputs strides: [(32, 8), (40, 8)]
 Inputs values: ['not shown', 'not shown']

tests/compile/test_builders.py

Lines changed: 3 additions & 3 deletions
@@ -580,10 +580,10 @@ def test_debugprint():
 
 OpFromGraph{inline=False} [id A]
 >Elemwise{add,no_inplace} [id E]
-> |*0-<TensorType(float64, (None, None))> [id F]
+> |*0-<TensorType(float64, (?, ?))> [id F]
 > |Elemwise{mul,no_inplace} [id G]
-> |*1-<TensorType(float64, (None, None))> [id H]
-> |*2-<TensorType(float64, (None, None))> [id I]
+> |*1-<TensorType(float64, (?, ?))> [id H]
+> |*2-<TensorType(float64, (?, ?))> [id I]
 """
 
     for truth, out in zip(exp_res.split("\n"), lines):

tests/scan/test_printing.py

Lines changed: 25 additions & 25 deletions
@@ -58,8 +58,8 @@ def test_debugprint_sitsot():
 
 for{cpu,scan_fn} [id C] (outer_out_sit_sot-0)
 >Elemwise{mul,no_inplace} [id W] (inner_out_sit_sot-0)
-> |*0-<TensorType(float64, (None,))> [id X] -> [id E] (inner_in_sit_sot-0)
-> |*1-<TensorType(float64, (None,))> [id Y] -> [id M] (inner_in_non_seqs-0)"""
+> |*0-<TensorType(float64, (?,))> [id X] -> [id E] (inner_in_sit_sot-0)
+> |*1-<TensorType(float64, (?,))> [id Y] -> [id M] (inner_in_non_seqs-0)"""
 
     for truth, out in zip(expected_output.split("\n"), lines):
         assert truth.strip() == out.strip()
@@ -113,8 +113,8 @@ def test_debugprint_sitsot_no_extra_info():
 
 for{cpu,scan_fn} [id C]
 >Elemwise{mul,no_inplace} [id W]
-> |*0-<TensorType(float64, (None,))> [id X] -> [id E]
-> |*1-<TensorType(float64, (None,))> [id Y] -> [id M]"""
+> |*0-<TensorType(float64, (?,))> [id X] -> [id E]
+> |*1-<TensorType(float64, (?,))> [id Y] -> [id M]"""
 
     for truth, out in zip(expected_output.split("\n"), lines):
         assert truth.strip() == out.strip()
@@ -264,7 +264,7 @@ def compute_A_k(A, k):
 > | | | | | | | |Unbroadcast{0} [id BL]
 > | | | | | | | |InplaceDimShuffle{x,0} [id BM]
 > | | | | | | | |Elemwise{second,no_inplace} [id BN]
-> | | | | | | | |*2-<TensorType(float64, (None,))> [id BO] -> [id W] (inner_in_non_seqs-0)
+> | | | | | | | |*2-<TensorType(float64, (?,))> [id BO] -> [id W] (inner_in_non_seqs-0)
 > | | | | | | | |InplaceDimShuffle{x} [id BP]
 > | | | | | | | |TensorConstant{1.0} [id BQ]
 > | | | | | | |ScalarConstant{0} [id BR]
@@ -275,16 +275,16 @@ def compute_A_k(A, k):
 > | | | | |Unbroadcast{0} [id BL]
 > | | | | |ScalarFromTensor [id BV]
 > | | | | |Subtensor{int64} [id BJ]
-> | | | |*2-<TensorType(float64, (None,))> [id BO] -> [id W] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
+> | | | |*2-<TensorType(float64, (?,))> [id BO] -> [id W] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
 > | | |ScalarConstant{1} [id BW]
 > | |ScalarConstant{-1} [id BX]
 > |InplaceDimShuffle{x} [id BY]
 > |*1-<TensorType(int64, ())> [id BZ] -> [id U] (inner_in_seqs-1)
 
 for{cpu,scan_fn} [id BE] (outer_out_sit_sot-0)
 >Elemwise{mul,no_inplace} [id CA] (inner_out_sit_sot-0)
-> |*0-<TensorType(float64, (None,))> [id CB] -> [id BG] (inner_in_sit_sot-0)
-> |*1-<TensorType(float64, (None,))> [id CC] -> [id BO] (inner_in_non_seqs-0)"""
+> |*0-<TensorType(float64, (?,))> [id CB] -> [id BG] (inner_in_sit_sot-0)
+> |*1-<TensorType(float64, (?,))> [id CC] -> [id BO] (inner_in_non_seqs-0)"""
 
     for truth, out in zip(expected_output.split("\n"), lines):
         assert truth.strip() == out.strip()
@@ -334,7 +334,7 @@ def compute_A_k(A, k):
 for{cpu,scan_fn} [id E] (outer_out_nit_sot-0)
 -*0-<TensorType(float64, ())> [id Y] -> [id U] (inner_in_seqs-0)
 -*1-<TensorType(int64, ())> [id Z] -> [id W] (inner_in_seqs-1)
--*2-<TensorType(float64, (None,))> [id BA] -> [id C] (inner_in_non_seqs-0)
+-*2-<TensorType(float64, (?,))> [id BA] -> [id C] (inner_in_non_seqs-0)
 -*3-<TensorType(int32, ())> [id BB] -> [id B] (inner_in_non_seqs-1)
 >Elemwise{mul,no_inplace} [id BC] (inner_out_nit_sot-0)
 > |InplaceDimShuffle{x} [id BD]
@@ -353,7 +353,7 @@ def compute_A_k(A, k):
 > | | | | | | | |Unbroadcast{0} [id BN]
 > | | | | | | | |InplaceDimShuffle{x,0} [id BO]
 > | | | | | | | |Elemwise{second,no_inplace} [id BP]
-> | | | | | | | |*2-<TensorType(float64, (None,))> [id BA] (inner_in_non_seqs-0)
+> | | | | | | | |*2-<TensorType(float64, (?,))> [id BA] (inner_in_non_seqs-0)
 > | | | | | | | |InplaceDimShuffle{x} [id BQ]
 > | | | | | | | |TensorConstant{1.0} [id BR]
 > | | | | | | |ScalarConstant{0} [id BS]
@@ -364,18 +364,18 @@ def compute_A_k(A, k):
 > | | | | |Unbroadcast{0} [id BN]
 > | | | | |ScalarFromTensor [id BW]
 > | | | | |Subtensor{int64} [id BL]
-> | | | |*2-<TensorType(float64, (None,))> [id BA] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
+> | | | |*2-<TensorType(float64, (?,))> [id BA] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
 > | | |ScalarConstant{1} [id BX]
 > | |ScalarConstant{-1} [id BY]
 > |InplaceDimShuffle{x} [id BZ]
 > |*1-<TensorType(int64, ())> [id Z] (inner_in_seqs-1)
 
 for{cpu,scan_fn} [id BH] (outer_out_sit_sot-0)
--*0-<TensorType(float64, (None,))> [id CA] -> [id BI] (inner_in_sit_sot-0)
--*1-<TensorType(float64, (None,))> [id CB] -> [id BA] (inner_in_non_seqs-0)
+-*0-<TensorType(float64, (?,))> [id CA] -> [id BI] (inner_in_sit_sot-0)
+-*1-<TensorType(float64, (?,))> [id CB] -> [id BA] (inner_in_non_seqs-0)
 >Elemwise{mul,no_inplace} [id CC] (inner_out_sit_sot-0)
-> |*0-<TensorType(float64, (None,))> [id CA] (inner_in_sit_sot-0)
-> |*1-<TensorType(float64, (None,))> [id CB] (inner_in_non_seqs-0)"""
+> |*0-<TensorType(float64, (?,))> [id CA] (inner_in_sit_sot-0)
+> |*1-<TensorType(float64, (?,))> [id CB] (inner_in_non_seqs-0)"""
 
     for truth, out in zip(expected_output.split("\n"), lines):
         assert truth.strip() == out.strip()
@@ -413,7 +413,7 @@ def fn(a_m2, a_m1, b_m2, b_m1):
 | | | | |Subtensor{int64} [id H]
 | | | | |Shape [id I]
 | | | | | |Subtensor{:int64:} [id J]
-| | | | | |<TensorType(int64, (None,))> [id K]
+| | | | | |<TensorType(int64, (?,))> [id K]
 | | | | | |ScalarConstant{2} [id L]
 | | | | |ScalarConstant{0} [id M]
 | | | |Subtensor{:int64:} [id J]
@@ -426,7 +426,7 @@ def fn(a_m2, a_m1, b_m2, b_m1):
 | | | |Subtensor{int64} [id R]
 | | | |Shape [id S]
 | | | | |Subtensor{:int64:} [id T]
-| | | | |<TensorType(int64, (None,))> [id U]
+| | | | |<TensorType(int64, (?,))> [id U]
 | | | | |ScalarConstant{2} [id V]
 | | | |ScalarConstant{0} [id W]
 | | |Subtensor{:int64:} [id T]
@@ -562,19 +562,19 @@ def test_debugprint_mitmot():
 for{cpu,grad_of_scan_fn}.1 [id B] (outer_out_sit_sot-0)
 >Elemwise{add,no_inplace} [id CM] (inner_out_mit_mot-0-0)
 > |Elemwise{mul} [id CN]
-> | |*2-<TensorType(float64, (None,))> [id CO] -> [id BL] (inner_in_mit_mot-0-0)
-> | |*5-<TensorType(float64, (None,))> [id CP] -> [id P] (inner_in_non_seqs-0)
-> |*3-<TensorType(float64, (None,))> [id CQ] -> [id BL] (inner_in_mit_mot-0-1)
+> | |*2-<TensorType(float64, (?,))> [id CO] -> [id BL] (inner_in_mit_mot-0-0)
+> | |*5-<TensorType(float64, (?,))> [id CP] -> [id P] (inner_in_non_seqs-0)
+> |*3-<TensorType(float64, (?,))> [id CQ] -> [id BL] (inner_in_mit_mot-0-1)
 >Elemwise{add,no_inplace} [id CR] (inner_out_sit_sot-0)
 > |Elemwise{mul} [id CS]
-> | |*2-<TensorType(float64, (None,))> [id CO] -> [id BL] (inner_in_mit_mot-0-0)
-> | |*0-<TensorType(float64, (None,))> [id CT] -> [id Z] (inner_in_seqs-0)
-> |*4-<TensorType(float64, (None,))> [id CU] -> [id CE] (inner_in_sit_sot-0)
+> | |*2-<TensorType(float64, (?,))> [id CO] -> [id BL] (inner_in_mit_mot-0-0)
+> | |*0-<TensorType(float64, (?,))> [id CT] -> [id Z] (inner_in_seqs-0)
+> |*4-<TensorType(float64, (?,))> [id CU] -> [id CE] (inner_in_sit_sot-0)
 
 for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
 >Elemwise{mul,no_inplace} [id CV] (inner_out_sit_sot-0)
-> |*0-<TensorType(float64, (None,))> [id CT] -> [id H] (inner_in_sit_sot-0)
-> |*1-<TensorType(float64, (None,))> [id CW] -> [id P] (inner_in_non_seqs-0)"""
+> |*0-<TensorType(float64, (?,))> [id CT] -> [id H] (inner_in_sit_sot-0)
+> |*1-<TensorType(float64, (?,))> [id CW] -> [id P] (inner_in_non_seqs-0)"""
 
     for truth, out in zip(expected_output.split("\n"), lines):
         assert truth.strip() == out.strip()

tests/tensor/rewriting/test_basic.py

Lines changed: 2 additions & 2 deletions
@@ -614,8 +614,8 @@ def test_eq(self):
         f2 = function([x], eq(x, x), mode=self.mode)
         assert np.all(f2(vx) == np.ones((5, 4)))
         topo2 = f2.maker.fgraph.toposort()
-        # Shape_i{1}(<TensorType(float64, (None, None))>),
-        # Shape_i{0}(<TensorType(float64, (None, None))>), Alloc([[1]], Shape_i{0}.0,
+        # Shape_i{1}(<TensorType(float64, (?, ?))>),
+        # Shape_i{0}(<TensorType(float64, (?, ?))>), Alloc([[1]], Shape_i{0}.0,
         # Shape_i{1}.0
         assert len(topo2) == 3
         assert isinstance(topo2[-1].op, Alloc)

tests/tensor/rewriting/test_elemwise.py

Lines changed: 12 additions & 4 deletions
@@ -82,6 +82,7 @@ def test_double_transpose(self):
         x, y, z = inputs()
         e = ds(ds(x, (1, 0)), (1, 0))
         g = FunctionGraph([x], [e])
+        # TODO FIXME: Construct these graphs and compare them.
         assert (
             str(g) == "FunctionGraph(InplaceDimShuffle{1,0}(InplaceDimShuffle{1,0}(x)))"
         )
@@ -93,6 +94,7 @@ def test_merge2(self):
         x, y, z = inputs()
         e = ds(ds(x, (1, "x", 0)), (2, 0, "x", 1))
         g = FunctionGraph([x], [e])
+        # TODO FIXME: Construct these graphs and compare them.
         assert (
             str(g)
             == "FunctionGraph(InplaceDimShuffle{2,0,x,1}(InplaceDimShuffle{1,x,0}(x)))"
@@ -106,6 +108,7 @@ def test_elim3(self):
         x, y, z = inputs()
         e = ds(ds(ds(x, (0, "x", 1)), (2, 0, "x", 1)), (1, 0))
         g = FunctionGraph([x], [e])
+        # TODO FIXME: Construct these graphs and compare them.
         assert str(g) == (
             "FunctionGraph(InplaceDimShuffle{1,0}(InplaceDimShuffle{2,0,x,1}"
             "(InplaceDimShuffle{0,x,1}(x))))"
@@ -119,6 +122,7 @@ def test_lift(self):
         e = x + y + z
         g = FunctionGraph([x, y, z], [e])
 
+        # TODO FIXME: Construct these graphs and compare them.
         # It does not really matter if the DimShuffles are inplace
         # or not.
         init_str_g_inplace = (
@@ -149,24 +153,25 @@ def test_recursive_lift(self):
         m = matrix(dtype="float64")
         out = ((v + 42) * (m + 84)).T
         g = FunctionGraph([v, m], [out])
+        # TODO FIXME: Construct these graphs and compare them.
         init_str_g = (
             "FunctionGraph(InplaceDimShuffle{1,0}(Elemwise{mul,no_inplace}"
             "(InplaceDimShuffle{x,0}(Elemwise{add,no_inplace}"
-            "(<TensorType(float64, (None,))>, "
+            "(<TensorType(float64, (?,))>, "
             "InplaceDimShuffle{x}(TensorConstant{42}))), "
             "Elemwise{add,no_inplace}"
-            "(<TensorType(float64, (None, None))>, "
+            "(<TensorType(float64, (?, ?))>, "
             "InplaceDimShuffle{x,x}(TensorConstant{84})))))"
         )
         assert str(g) == init_str_g
         new_out = local_dimshuffle_lift.transform(g, g.outputs[0].owner)[0]
         new_g = FunctionGraph(g.inputs, [new_out])
         rewrite_str_g = (
             "FunctionGraph(Elemwise{mul,no_inplace}(Elemwise{add,no_inplace}"
-            "(InplaceDimShuffle{0,x}(<TensorType(float64, (None,))>), "
+            "(InplaceDimShuffle{0,x}(<TensorType(float64, (?,))>), "
             "InplaceDimShuffle{x,x}(TensorConstant{42})), "
             "Elemwise{add,no_inplace}(InplaceDimShuffle{1,0}"
-            "(<TensorType(float64, (None, None))>), "
+            "(<TensorType(float64, (?, ?))>), "
             "InplaceDimShuffle{x,x}(TensorConstant{84}))))"
         )
         assert str(new_g) == rewrite_str_g
@@ -177,6 +182,7 @@ def test_useless_dimshuffle(self):
         x, _, _ = inputs()
         e = ds(x, (0, 1))
         g = FunctionGraph([x], [e])
+        # TODO FIXME: Construct these graphs and compare them.
         assert str(g) == "FunctionGraph(InplaceDimShuffle{0,1}(x))"
         dimshuffle_lift.rewrite(g)
         assert str(g) == "FunctionGraph(x)"
@@ -191,6 +197,7 @@ def test_dimshuffle_on_broadcastable(self):
         ds_z = ds(z, (2, 1, 0)) # useful
         ds_u = ds(u, ("x")) # useful
         g = FunctionGraph([x, y, z, u], [ds_x, ds_y, ds_z, ds_u])
+        # TODO FIXME: Construct these graphs and compare them.
         assert (
             str(g)
             == "FunctionGraph(InplaceDimShuffle{0,x}(x), InplaceDimShuffle{2,1,0}(y), InplaceDimShuffle{2,1,0}(z), InplaceDimShuffle{x}(TensorConstant{1}))"
@@ -225,6 +232,7 @@ def test_local_useless_dimshuffle_in_reshape():
         ],
     )
 
+    # TODO FIXME: Construct these graphs and compare them.
    assert str(g) == (
        "FunctionGraph(Reshape{1}(InplaceDimShuffle{x,0}(vector), Shape(vector)), "
        "Reshape{2}(InplaceDimShuffle{x,0,x,1}(mat), Shape(mat)), "

tests/tensor/test_sharedvar.py

Lines changed: 1 addition & 1 deletion
@@ -513,7 +513,7 @@ def test_specify_shape_inplace(self):
         )
         topo = f.maker.fgraph.toposort()
         f()
-        # [Gemm{inplace}(<TensorType(float64, (None, None))>, 0.01, <TensorType(float64, (None, None))>, <TensorType(float64, (None, None))>, 2e-06)]
+        # [Gemm{inplace}(<TensorType(float64, (?, ?))>, 0.01, <TensorType(float64, (?, ?))>, <TensorType(float64, (?, ?))>, 2e-06)]
         if aesara.config.mode != "FAST_COMPILE":
             assert (
                 sum(
