@@ -1153,20 +1153,44 @@ def test_tree_all(self):
1153
1153
tree_util .tree_all (obj ),
1154
1154
)
1155
1155
1156
+ def test_tree_all_is_leaf (self ):
1157
+ obj = [True , True , (True , False )]
1158
+ is_leaf = lambda x : isinstance (x , tuple )
1159
+ self .assertEqual (
1160
+ jax .tree .all (obj , is_leaf = is_leaf ),
1161
+ tree_util .tree_all (obj , is_leaf = is_leaf ),
1162
+ )
1163
+
1156
1164
def test_tree_flatten (self ):
1157
1165
obj = [1 , 2 , (3 , 4 )]
1158
1166
self .assertEqual (
1159
1167
jax .tree .flatten (obj ),
1160
1168
tree_util .tree_flatten (obj ),
1161
1169
)
1162
1170
1171
+ def test_tree_flatten_is_leaf (self ):
1172
+ obj = [1 , 2 , (3 , 4 )]
1173
+ is_leaf = lambda x : isinstance (x , tuple )
1174
+ self .assertEqual (
1175
+ jax .tree .flatten (obj , is_leaf = is_leaf ),
1176
+ tree_util .tree_flatten (obj , is_leaf = is_leaf ),
1177
+ )
1178
+
1163
1179
def test_tree_leaves (self ):
1164
1180
obj = [1 , 2 , (3 , 4 )]
1165
1181
self .assertEqual (
1166
1182
jax .tree .leaves (obj ),
1167
1183
tree_util .tree_leaves (obj ),
1168
1184
)
1169
1185
1186
+ def test_tree_leaves_is_leaf (self ):
1187
+ obj = [1 , 2 , (3 , 4 )]
1188
+ is_leaf = lambda x : isinstance (x , tuple )
1189
+ self .assertEqual (
1190
+ jax .tree .leaves (obj , is_leaf = is_leaf ),
1191
+ tree_util .tree_leaves (obj , is_leaf = is_leaf ),
1192
+ )
1193
+
1170
1194
def test_tree_map (self ):
1171
1195
func = lambda x : x * 2
1172
1196
obj = [1 , 2 , (3 , 4 )]
@@ -1175,6 +1199,15 @@ def test_tree_map(self):
1175
1199
tree_util .tree_map (func , obj ),
1176
1200
)
1177
1201
1202
+ def test_tree_map_is_leaf (self ):
1203
+ func = lambda x : x * 2
1204
+ obj = [1 , 2 , (3 , 4 )]
1205
+ is_leaf = lambda x : isinstance (x , tuple )
1206
+ self .assertEqual (
1207
+ jax .tree .map (func , obj , is_leaf = is_leaf ),
1208
+ tree_util .tree_map (func , obj , is_leaf = is_leaf ),
1209
+ )
1210
+
1178
1211
def test_tree_reduce (self ):
1179
1212
func = lambda a , b : a + b
1180
1213
obj = [1 , 2 , (3 , 4 )]
@@ -1183,13 +1216,30 @@ def test_tree_reduce(self):
1183
1216
tree_util .tree_reduce (func , obj ),
1184
1217
)
1185
1218
1219
+ def test_tree_reduce_is_leaf (self ):
1220
+ func = lambda a , b : a + b
1221
+ obj = [(1 , 2 ), (3 , 4 )]
1222
+ is_leaf = lambda x : isinstance (x , tuple )
1223
+ self .assertEqual (
1224
+ jax .tree .reduce (func , obj , is_leaf = is_leaf ),
1225
+ tree_util .tree_reduce (func , obj , is_leaf = is_leaf ),
1226
+ )
1227
+
1186
1228
def test_tree_structure (self ):
1187
1229
obj = [1 , 2 , (3 , 4 )]
1188
1230
self .assertEqual (
1189
1231
jax .tree .structure (obj ),
1190
1232
tree_util .tree_structure (obj ),
1191
1233
)
1192
1234
1235
+ def test_tree_structure_is_leaf (self ):
1236
+ obj = [1 , 2 , (3 , 4 )]
1237
+ is_leaf = lambda x : isinstance (x , tuple )
1238
+ self .assertEqual (
1239
+ jax .tree .structure (obj , is_leaf = is_leaf ),
1240
+ tree_util .tree_structure (obj , is_leaf = is_leaf ),
1241
+ )
1242
+
1193
1243
def test_tree_transpose (self ):
1194
1244
obj = [(1 , 2 ), (3 , 4 ), (5 , 6 )]
1195
1245
outer_treedef = tree_util .tree_structure (['*' , '*' , '*' ])
0 commit comments