@@ -79,7 +79,6 @@
 import logging
 import os
 import shlex
-import time
 from pathlib import Path
 
 import numpy as np
@@ -103,10 +102,8 @@
 from pytensor.tensor import basic as ptb
 from pytensor.tensor.basic import expand_dims
 from pytensor.tensor.blas_headers import blas_header_text, blas_header_version
-from pytensor.tensor.elemwise import DimShuffle
-from pytensor.tensor.math import add, mul, neg, sub, variadic_add
 from pytensor.tensor.shape import shape_padright, specify_broadcastable
-from pytensor.tensor.type import DenseTensorType, TensorType, integer_dtypes, tensor
+from pytensor.tensor.type import DenseTensorType, tensor
 
 
 _logger = logging.getLogger("pytensor.tensor.blas")
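Each dropped import (`time`, `DimShuffle`, `add`/`mul`/`neg`/`sub`/`variadic_add`, `TensorType`, `integer_dtypes`) was referenced only by the helper functions deleted in the hunk below; a linter pass such as `python -m pyflakes pytensor/tensor/blas.py` is a quick way to confirm no other use remains.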
@@ -1148,322 +1145,6 @@ def c_code_cache_version(self):
 pprint.assign(gemm_no_inplace, FunctionPrinter(["gemm_no_inplace"]))
 
 
-def res_is_a(fgraph, var, op, maxclients=None):
-    if maxclients is not None and var in fgraph.clients:
-        retval = len(fgraph.get_clients(var)) <= maxclients
-    else:
-        retval = True
-
-    return var.owner and var.owner.op == op and retval
-
-
-def _as_scalar(res, dtype=None):
-    """Return ``None`` or a `TensorVariable` of float type"""
-    if dtype is None:
-        dtype = config.floatX
-    if all(s == 1 for s in res.type.shape):
-        while res.owner and isinstance(res.owner.op, DimShuffle):
-            res = res.owner.inputs[0]
-        # res may still have some number of size-1 (broadcastable) dimensions
-        if res.type.ndim > 0:
-            rval = res.dimshuffle()
-        else:
-            rval = res
-        if rval.type.dtype in integer_dtypes:
-            # We check that the upcast of res and dtype won't change dtype.
-            # If dtype is float64, we will cast int64 to float64.
-            # This is valid when res is a scalar used as input to a dot22,
-            # as the cast of the scalar can be done before or after the dot22
-            # and this will give the same result.
-            if pytensor.scalar.upcast(res.dtype, dtype) == dtype:
-                return ptb.cast(rval, dtype)
-            else:
-                return None
-
-        return rval
-
-
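`_as_scalar` exploited the fact that a variable whose static shape is all ones carries only broadcastable dimensions and can be collapsed to a 0-d scalar, which is what made it usable as the `alpha`/`beta` argument of a gemm. A minimal sketch of that collapse (illustrative, not part of the diff; assumes a recent PyTensor):

```python
import pytensor.tensor as pt

# A (1, 1)-shaped variable has only broadcastable dimensions, so
# dimshuffle() with an empty pattern drops them all, yielding a scalar.
x = pt.tensor("x", dtype="float64", shape=(1, 1))
s = x.dimshuffle()
print(s.type.ndim)  # 0
```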
-def _is_real_matrix(res):
-    return (
-        res.type.dtype in ("float16", "float32", "float64")
-        and res.type.ndim == 2
-        and res.type.shape[0] != 1
-        and res.type.shape[1] != 1
-    )  # cope with tuple vs. list
-
-
-def _is_real_vector(res):
-    return (
-        res.type.dtype in ("float16", "float32", "float64")
-        and res.type.ndim == 1
-        and res.type.shape[0] != 1
-    )
-
-
-def _beta_L_plus_alpha_M(fgraph, beta, L, alpha, M, recurse_flip=True):
-    # print 'BETA L + ALPHA M', beta, L, alpha, M, recurse_flip
-    # EXPRESSION: (beta * L) + (alpha * M)
-
-    # we've already checked the client counts; now just make the type check.
-    # if res_is_a(M, _dot22, 1):
-    if M.owner and M.owner.op == _dot22:
-        Ml, Mr = M.owner.inputs
-        rval = [gemm_no_inplace(L, alpha, Ml, Mr, beta)]
-        return rval, M
-
-    # There might also be a dimshuffle between the + and the dot22;
-    # local_dot_to_dot22 in particular will put in such things.
-    if (
-        M.owner
-        and isinstance(M.owner.op, DimShuffle)
-        and M.owner.inputs[0].owner
-        and isinstance(M.owner.inputs[0].owner.op, Dot22)
-    ):
-        MM = M.owner.inputs[0]
-        if M.owner.op.new_order == (0,):
-            # it is making a column MM into a vector
-            MMl, MMr = MM.owner.inputs
-            g = gemm_no_inplace(L.dimshuffle(0, "x"), alpha, MMl, MMr, beta)
-            rval = [g.dimshuffle(0)]
-            return rval, MM
-        if M.owner.op.new_order == (1,):
-            # it is making a row MM into a vector
-            MMl, MMr = MM.owner.inputs
-            g = gemm_no_inplace(L.dimshuffle("x", 0), alpha, MMl, MMr, beta)
-            rval = [g.dimshuffle(1)]
-            return rval, MM
-        if len(M.owner.op.new_order) == 0:
-            # it is making a 1x1 MM into a scalar
-            MMl, MMr = MM.owner.inputs
-            g = gemm_no_inplace(L.dimshuffle("x", "x"), alpha, MMl, MMr, beta)
-            rval = [g.dimshuffle()]
-            return rval, MM
-
-    if recurse_flip:
-        return _beta_L_plus_alpha_M(fgraph, alpha, M, beta, L, recurse_flip=False)
-    else:
-        return False, False
-
-
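What `_beta_L_plus_alpha_M` emitted is the standard BLAS GEMM contract, `Z = beta * L + alpha * (Ml @ Mr)`, as a single fused op. A minimal NumPy/SciPy sketch of that identity (illustrative only; `dgemm` here is SciPy's BLAS wrapper, not the PyTensor op):

```python
import numpy as np
from scipy.linalg.blas import dgemm

rng = np.random.default_rng(0)
Ml, Mr = rng.normal(size=(3, 4)), rng.normal(size=(4, 2))
L = rng.normal(size=(3, 2))
alpha, beta = 0.5, 2.0

# One fused BLAS call computes what the add/mul/dot22 subgraph spelled out.
z = dgemm(alpha, Ml, Mr, beta=beta, c=L)
assert np.allclose(z, beta * L + alpha * (Ml @ Mr))
```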
-def _gemm_canonicalize(fgraph, r, scale, rval, maxclients):
-    # Tries to interpret node as a sum of scalars * (vectors or matrices)
-    def scaled(thing):
-        if scale == 1:
-            return thing
-        if scale == -1 and thing.type.dtype != "bool":
-            return -thing
-        else:
-            return scale * thing
-
-    if not isinstance(r.type, TensorType):
-        return None
-
-    if (r.type.ndim not in (1, 2)) or r.type.dtype not in (
-        "float16",
-        "float32",
-        "float64",
-        "complex64",
-        "complex128",
-    ):
-        rval.append(scaled(r))
-        return rval
-
-    if maxclients and len(fgraph.clients[r]) > maxclients:
-        rval.append((scale, r))
-        return rval
-
-    if r.owner and r.owner.op == sub:
-        _gemm_canonicalize(fgraph, r.owner.inputs[0], scale, rval, 1)
-        _gemm_canonicalize(fgraph, r.owner.inputs[1], -scale, rval, 1)
-
-    elif r.owner and r.owner.op == add:
-        for i in r.owner.inputs:
-            _gemm_canonicalize(fgraph, i, scale, rval, 1)
-
-    elif r.owner and r.owner.op == neg:
-        _gemm_canonicalize(fgraph, r.owner.inputs[0], -scale, rval, 1)
-
-    elif r.owner and r.owner.op == mul:
-        scalars = []
-        vectors = []
-        matrices = []
-        for i in r.owner.inputs:
-            if all(s == 1 for s in i.type.shape):
-                while i.owner and isinstance(i.owner.op, DimShuffle):
-                    i = i.owner.inputs[0]
-                if i.type.ndim > 0:
-                    scalars.append(i.dimshuffle())
-                else:
-                    scalars.append(i)
-            elif _is_real_vector(i):
-                vectors.append(i)
-            elif _is_real_matrix(i):
-                matrices.append(i)
-            else:
-                # just record the original term, as in the base case
-                rval.append((scale, r))
-                return rval
-        if len(matrices) == 1:
-            assert len(vectors) == 0
-            m = matrices[0]
-            if len(scalars) == 0:
-                _gemm_canonicalize(fgraph, m, scale, rval, 1)
-            elif len(scalars) == 1:
-                _gemm_canonicalize(fgraph, m, scaled(scalars[0]), rval, 1)
-            else:
-                _gemm_canonicalize(
-                    fgraph, m, mul(scaled(scalars[0]), *scalars[1:]), rval, 1
-                )
-        elif len(vectors) == 1:
-            assert len(matrices) == 0
-            v = vectors[0]
-            if len(scalars) == 0:
-                _gemm_canonicalize(fgraph, v, scale, rval, 1)
-            elif len(scalars) == 1:
-                _gemm_canonicalize(fgraph, v, scaled(scalars[0]), rval, 1)
-            else:
-                _gemm_canonicalize(
-                    fgraph, v, mul(scaled(scalars[0]), *scalars[1:]), rval, 1
-                )
-        else:  # let's not open this up
-            rval.append((scale, r))
-    else:
-        rval.append((scale, r))
-    return rval
-
-
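`_gemm_canonicalize` flattens a tree of `add`/`sub`/`neg`/`mul` nodes into a flat list of `(scale, term)` pairs, folding signs into the scales. A toy standalone version of the same idea, using hypothetical tuple-tagged expressions in place of PyTensor nodes:

```python
# Expressions are ("add", *terms), ("neg", term), or opaque leaves;
# the output is a flat list of (scale, leaf) pairs with signs folded in.
def canonicalize(expr, scale=1.0, out=None):
    out = [] if out is None else out
    if isinstance(expr, tuple) and expr[0] == "add":
        for term in expr[1:]:
            canonicalize(term, scale, out)
    elif isinstance(expr, tuple) and expr[0] == "neg":
        canonicalize(expr[1], -scale, out)
    else:
        out.append((scale, expr))
    return out

# a - (b + c)  ->  [(1.0, 'a'), (-1.0, 'b'), (-1.0, 'c')]
print(canonicalize(("add", "a", ("neg", ("add", "b", "c")))))
```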
-def _factor_canonicalized(lst):
-    # remove duplicates from canonicalized list
-
-    # we only delete out of the right end of the list,
-    # once i has touched a list element, it is permanent
-    lst = list(lst)
-    # print 'FACTOR', lst
-    # for t in lst:
-    #     if not isinstance(t, (list, tuple)):
-    #         t = (t,)
-    #     for e in t:
-    #         try:
-    #             pytensor.printing.debugprint(e)
-    #         except TypeError:
-    #             print e, type(e)
-    i = 0
-    while i < len(lst) - 1:
-        try:
-            s_i, M_i = lst[i]
-        except Exception:
-            i += 1
-            continue
-
-        j = i + 1
-        while j < len(lst):
-            try:
-                s_j, M_j = lst[j]
-            except Exception:
-                j += 1
-                continue
-
-            if M_i is M_j:
-                s_i = s_i + s_j
-                lst[i] = (s_i, M_i)
-                del lst[j]
-            else:
-                j += 1
-        i += 1
-    return lst
-
-
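`_factor_canonicalized` merges duplicate terms by object identity and sums their scales, so `x + 2*x` collapses to a single `(3.0, x)` entry. A compact sketch of just that merge rule (illustrative; the deleted version also tolerates non-pair entries):

```python
def factor(pairs):
    out = []
    for s, m in pairs:
        for k, (s_k, m_k) in enumerate(out):
            if m_k is m:  # same graph variable: fold the scales together
                out[k] = (s_k + s, m_k)
                break
        else:
            out.append((s, m))
    return out

x, y = object(), object()
assert factor([(1.0, x), (2.0, x), (0.5, y)])[0] == (3.0, x)
```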
-def _gemm_from_factored_list(fgraph, lst):
-    """
-    Returns ``None``, or a ``(replacement_outputs, old_dot22)`` pair.
-
-    """
-    lst2 = []
-    # Remove the tuples that can't be cast correctly.
-    # This can happen when we try to cast a complex to a real.
-    for sM in lst:
-        # Make every pair in the list have matching dtypes.
-        # sM can be a tuple of 2 elements or a PyTensor variable.
-        if isinstance(sM, tuple):
-            sm0, sm1 = sM
-            sm0 = ptb.as_tensor_variable(sm0)
-            if pytensor.scalar.upcast(sm0.dtype, sm1.dtype) == sm1.dtype:
-                lst2.append((ptb.cast(sm0, sm1.dtype), sM[1]))
-
-    lst = lst2
-
-    def item_to_var(t):
-        try:
-            s, M = t
-        except Exception:
-            return t
-        if s == 1:
-            return M
-        if s == -1:
-            return -M
-        return s * M
-
-    # Try every pair in lst, trying to turn it into a gemm operation.
-    for i in range(len(lst) - 1):
-        s_i, M_i = lst[i]
-
-        for j in range(i + 1, len(lst)):
-            s_j, M_j = lst[j]
-
-            if not M_j.type.in_same_class(M_i.type):
-                continue
-
-            # print 'TRYING', (s_i, M_i, s_j, M_j)
-
-            gemm_of_sM_list, old_dot22 = _beta_L_plus_alpha_M(
-                fgraph, s_i, M_i, s_j, M_j
-            )
-            # print 'GOT IT', gemm_of_sM_list
-            if gemm_of_sM_list:
-                assert len(gemm_of_sM_list) == 1
-                add_inputs = [
-                    item_to_var(input) for k, input in enumerate(lst) if k not in (i, j)
-                ]
-                add_inputs.extend(gemm_of_sM_list)
-                rval = [variadic_add(*add_inputs)]
-                return rval, old_dot22
-
-
-def _gemm_from_node2(fgraph, node):
-    """
-    TODO: Many expressions can be turned into a gemm in more than one
-    way, e.g. ``dot(a, b) + c + d``. This function should return all of
-    them, so that if one version of gemm causes a cycle in the graph,
-    another application of gemm can be tried.
-
-    """
-    lst = []
-    t0 = time.perf_counter()
-    _gemm_canonicalize(fgraph, node.outputs[0], 1.0, lst, 0)
-    t1 = time.perf_counter()
-
-    if len(lst) > 1:
-        lst = _factor_canonicalized(lst)
-        t2 = time.perf_counter()
-        rval = _gemm_from_factored_list(fgraph, lst)
-        t3 = time.perf_counter()
-
-        # It can happen that _factor_canonicalized and
-        # _gemm_from_factored_list return a node with an incorrect
-        # type. This happens in particular when one of the scalar
-        # factors forces the upcast of the whole expression. In that
-        # case, we simply skip that candidate for Gemm. This was
-        # discussed in
-        # http://groups.google.com/group/theano-dev/browse_thread/thread/a3096c82856e3ad5,
-        # but never made it into a trac ticket.
-
-        if rval and rval[0][0].type.in_same_class(node.outputs[0].type):
-            return rval, t1 - t0, t2 - t1, t3 - t2
-
-    return None, t1 - t0, 0, 0
-
-
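The three deleted helpers formed the old canonicalize → factor → pair-search pipeline behind the gemm rewrite. A quick way to inspect what a `beta * L + alpha * dot(A, B)` expression compiles to after this change (a sketch assuming a BLAS-enabled PyTensor build; if the fusion is still performed by the remaining rewrites, a `Gemm` node should appear in the printout):

```python
import pytensor
import pytensor.tensor as pt

A, B, L = pt.matrix("A"), pt.matrix("B"), pt.matrix("L")
out = 2.0 * L + 0.5 * (A @ B)

f = pytensor.function([A, B, L], out)
pytensor.dprint(f)  # look for a Gemm{...} node in the optimized graph
```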
 class Dot22(GemmRelated):
     """Compute a matrix-matrix product.
 