2
2
import pytest
3
3
4
4
from pymc .distributions import CustomDist
5
+ from pytensor .tensor .type_other import NoneTypeT
5
6
6
- from pymc_experimental .model .marginal .graph_analysis import subgraph_dim_connection
7
+ from pymc_experimental .model .marginal .graph_analysis import subgraph_batch_dim_connection
7
8
8
9
9
- class TestSubgraphDimConnection :
10
+ class TestSubgraphBatchDimConnection :
10
11
def test_dimshuffle (self ):
11
12
inp = pt .tensor (shape = (5 , 1 , 4 , 3 ))
12
13
out1 = pt .matrix_transpose (inp )
13
14
out2 = pt .expand_dims (inp , 1 )
14
15
out3 = pt .squeeze (inp )
15
- [dims1 , dims2 , dims3 ] = subgraph_dim_connection (inp , [], [out1 , out2 , out3 ])
16
+ [dims1 , dims2 , dims3 ] = subgraph_batch_dim_connection (inp , [], [out1 , out2 , out3 ])
16
17
assert dims1 == ((0 ,), (1 ,), (3 ,), (2 ,))
17
18
assert dims2 == ((0 ,), (), (1 ,), (2 ,), (3 ,))
18
19
assert dims3 == ((0 ,), (2 ,), (3 ,))
19
20
20
21
def test_careduce (self ):
21
22
inp = pt .tensor (shape = (4 , 3 , 2 ))
22
- out = pt .sum (inp , axis = (1 ,))
23
- [dims ] = subgraph_dim_connection (inp , [], [out ])
24
- assert dims == ((0 , 1 ), (2 , 1 ))
23
+
24
+ out = pt .sum (inp [:, None ], axis = (1 ,))
25
+ [dims ] = subgraph_batch_dim_connection (inp , [], [out ])
26
+ assert dims == ((0 ,), (1 ,), (2 ,))
27
+
28
+ invalid_out = pt .sum (inp , axis = (1 ,))
29
+ with pytest .raises (ValueError , match = "Use of known dimensions" ):
30
+ subgraph_batch_dim_connection (inp , [], [invalid_out ])
25
31
26
32
def test_subtensor (self ):
27
33
inp = pt .tensor (shape = (4 , 3 , 2 ))
28
34
29
35
invalid_out = inp [0 , :1 ]
30
36
with pytest .raises (
31
- NotImplementedError ,
37
+ ValueError ,
32
38
match = "Partial slicing or indexing of known dimensions not supported" ,
33
39
):
34
- subgraph_dim_connection (inp , [], [invalid_out ])
40
+ subgraph_batch_dim_connection (inp , [], [invalid_out ])
35
41
36
42
# If we are selecting dummy / unknown dimensions that's fine
37
43
valid_out = pt .expand_dims (inp , (0 , 1 ))[0 , :1 ]
38
- [dims ] = subgraph_dim_connection (inp , [], [valid_out ])
44
+ [dims ] = subgraph_batch_dim_connection (inp , [], [valid_out ])
39
45
assert dims == ((), (0 ,), (1 ,), (2 ,))
40
46
41
47
def test_advanced_subtensor_value (self ):
@@ -44,99 +50,116 @@ def test_advanced_subtensor_value(self):
44
50
45
51
# Index on an unlabled dim introduced by broadcasting with zeros
46
52
out = intermediate_out [:, [0 , 0 , 1 , 2 ]]
47
- [dims ] = subgraph_dim_connection (inp , [], [out ])
53
+ [dims ] = subgraph_batch_dim_connection (inp , [], [out ])
48
54
assert dims == ((0 ,), (), (1 ,), ())
49
55
50
56
# Indexing that introduces more dimensions
51
57
out = intermediate_out [:, [[0 , 0 ], [1 , 2 ]], :]
52
- [dims ] = subgraph_dim_connection (inp , [], [out ])
58
+ [dims ] = subgraph_batch_dim_connection (inp , [], [out ])
53
59
assert dims == ((0 ,), (), (), (1 ,), ())
54
60
55
61
# Special case where advanced dims are moved to the front of the output
56
62
out = intermediate_out [:, [0 , 0 , 1 , 2 ], :, 0 ]
57
- [dims ] = subgraph_dim_connection (inp , [], [out ])
63
+ [dims ] = subgraph_batch_dim_connection (inp , [], [out ])
58
64
assert dims == ((), (0 ,), (1 ,))
59
65
60
66
# Indexing on a labeled dim fails
61
67
out = intermediate_out [:, :, [0 , 0 , 1 , 2 ]]
62
- with pytest .raises (
63
- NotImplementedError , match = "Partial slicing or advanced integer indexing"
64
- ):
65
- subgraph_dim_connection (inp , [], [out ])
68
+ with pytest .raises (ValueError , match = "Partial slicing or advanced integer indexing" ):
69
+ subgraph_batch_dim_connection (inp , [], [out ])
66
70
67
71
def test_advanced_subtensor_key (self ):
68
72
inp = pt .tensor (shape = (5 , 5 ), dtype = int )
69
73
base = pt .zeros ((2 , 3 , 4 ))
70
74
71
75
out = base [inp ]
72
- [dims ] = subgraph_dim_connection (inp , [], [out ])
76
+ [dims ] = subgraph_batch_dim_connection (inp , [], [out ])
73
77
assert dims == ((0 ,), (1 ,), (), ())
74
78
75
79
out = base [:, :, inp ]
76
- [dims ] = subgraph_dim_connection (inp , [], [out ])
80
+ [dims ] = subgraph_batch_dim_connection (inp , [], [out ])
77
81
assert dims == ((), (), (0 ,), (1 ,))
78
82
79
83
out = base [1 :, 0 , inp ]
80
- [dims ] = subgraph_dim_connection (inp , [], [out ])
84
+ [dims ] = subgraph_batch_dim_connection (inp , [], [out ])
81
85
assert dims == ((), (0 ,), (1 ,))
82
86
83
87
# Special case where advanced dims are moved to the front of the output
84
88
out = base [0 , :, inp ]
85
- [dims ] = subgraph_dim_connection (inp , [], [out ])
89
+ [dims ] = subgraph_batch_dim_connection (inp , [], [out ])
86
90
assert dims == ((0 ,), (1 ,), ())
87
91
88
92
# Mix keys dimensions
89
93
out = base [:, inp , inp .T ]
90
- [dims ] = subgraph_dim_connection (inp , [], [out ])
94
+ [dims ] = subgraph_batch_dim_connection (inp , [], [out ])
91
95
assert dims == ((), (0 , 1 ), (0 , 1 ))
92
96
93
97
def test_elemwise (self ):
94
98
inp = pt .tensor (shape = (5 , 5 ))
95
99
96
100
out = inp + inp
97
- [dims ] = subgraph_dim_connection (inp , [], [out ])
101
+ [dims ] = subgraph_batch_dim_connection (inp , [], [out ])
98
102
assert dims == ((0 ,), (1 ,))
99
103
100
104
out = inp + inp .T
101
- [dims ] = subgraph_dim_connection (inp , [], [out ])
105
+ [dims ] = subgraph_batch_dim_connection (inp , [], [out ])
102
106
assert dims == (
103
107
(0 , 1 ),
104
- (
105
- 0 ,
106
- 1 ,
107
- ),
108
+ (0 , 1 ),
108
109
)
109
110
110
111
def test_blockwise (self ):
111
- inp = pt .tensor (shape = (5 , 4 , 3 , 2 ))
112
- out = inp @ pt .ones ((2 , 3 ))
113
- [dims ] = subgraph_dim_connection (inp , [], [out ])
114
- # Every dimension contains information from the core dimensions
115
- assert dims == ((0 , 2 , 3 ), (1 , 2 , 3 ), (2 , 3 ), (2 , 3 ))
112
+ inp = pt .tensor (shape = (5 , 4 ))
113
+
114
+ invalid_out = inp @ pt .ones ((4 , 3 ))
115
+ with pytest .raises (ValueError , match = "Use of known dimensions" ):
116
+ subgraph_batch_dim_connection (inp , [], [invalid_out ])
117
+
118
+ out = (inp [:, :, None , None ] + pt .zeros ((2 , 3 ))) @ pt .ones ((2 , 3 ))
119
+ [dims ] = subgraph_batch_dim_connection (inp , [], [out ])
120
+ assert dims == ((0 ,), (1 ,), (), ())
116
121
117
122
def test_random_variable (self ):
118
123
inp = pt .tensor (shape = (5 , 4 , 3 ))
124
+
119
125
out1 = pt .random .normal (loc = inp )
120
- out2 = pt .random .categorical (p = inp )
121
- out3 = pt .random .multivariate_normal (mean = inp , cov = pt .eye (3 ))
122
- [dims1 , dims2 , dims3 ] = subgraph_dim_connection (inp , [], [out1 , out2 , out3 ])
126
+ out2 = pt .random .categorical (p = inp [..., None ] )
127
+ out3 = pt .random .multivariate_normal (mean = inp [..., None ], cov = pt .eye (1 ))
128
+ [dims1 , dims2 , dims3 ] = subgraph_batch_dim_connection (inp , [], [out1 , out2 , out3 ])
123
129
assert dims1 == ((0 ,), (1 ,), (2 ,))
124
- assert dims2 == ((0 , 2 ), (1 , 2 ))
125
- assert dims3 == ((0 , 2 ), (1 , 2 ), (2 ,))
130
+ assert dims2 == ((0 ,), (1 ,), (2 ,))
131
+ assert dims3 == ((0 ,), (1 ,), (2 ,), ())
132
+
133
+ invalid_out = pt .random .categorical (p = inp )
134
+ with pytest .raises (ValueError , match = "Use of known dimensions" ):
135
+ subgraph_batch_dim_connection (inp , [], [invalid_out ])
136
+
137
+ invalid_out = pt .random .multivariate_normal (mean = inp , cov = pt .eye (3 ))
138
+ with pytest .raises (ValueError , match = "Use of known dimensions" ):
139
+ subgraph_batch_dim_connection (inp , [], [invalid_out ])
126
140
127
141
def test_symbolic_random_variable (self ):
128
142
inp = pt .tensor (shape = (4 , 3 , 2 ))
143
+
144
+ # Test univariate
129
145
out = CustomDist .dist (
130
146
inp ,
131
147
dist = lambda mu , size : pt .random .normal (loc = mu , size = size ),
132
148
)
133
- [dims ] = subgraph_dim_connection (inp , [], [out ])
149
+ [dims ] = subgraph_batch_dim_connection (inp , [], [out ])
134
150
assert dims == ((0 ,), (1 ,), (2 ,))
135
151
136
152
# Test multivariate
153
+ def dist (mu , size ):
154
+ if isinstance (size .type , NoneTypeT ):
155
+ size = mu .shape
156
+ return pt .random .normal (loc = mu [..., None ], size = (* size , 2 )).sum (- 1 )
157
+
137
158
out = CustomDist .dist (
138
159
inp ,
139
- dist = lambda mu , size : pt .random .normal (loc = mu , size = size ).sum (- 1 ),
160
+ dist = dist ,
161
+ size = (4 , 3 , 2 ),
162
+ ndim_supp = 1 ,
140
163
)
141
- [dims ] = subgraph_dim_connection (inp , [], [out ])
142
- assert dims == ((0 , 2 ), (1 , 2 ))
164
+ [dims ] = subgraph_batch_dim_connection (inp , [], [out ])
165
+ assert dims == ((0 ,), (1 ,), ( 2 , ))
0 commit comments