|
79 | 79 | import logging |
80 | 80 | import os |
81 | 81 | import shlex |
82 | | -import time |
83 | 82 | from pathlib import Path |
84 | 83 |
|
85 | 84 | import numpy as np |
|
103 | 102 | from pytensor.tensor import basic as ptb |
104 | 103 | from pytensor.tensor.basic import expand_dims |
105 | 104 | from pytensor.tensor.blas_headers import blas_header_text, blas_header_version |
106 | | -from pytensor.tensor.elemwise import DimShuffle |
107 | | -from pytensor.tensor.math import add, mul, neg, sub, variadic_add |
108 | 105 | from pytensor.tensor.shape import shape_padright, specify_broadcastable |
109 | | -from pytensor.tensor.type import DenseTensorType, TensorType, integer_dtypes, tensor |
| 106 | +from pytensor.tensor.type import DenseTensorType, tensor |
110 | 107 |
|
111 | 108 |
|
112 | 109 | _logger = logging.getLogger("pytensor.tensor.blas") |
@@ -1148,322 +1145,6 @@ def c_code_cache_version(self): |
1148 | 1145 | pprint.assign(gemm_no_inplace, FunctionPrinter(["gemm_no_inplace"])) |
1149 | 1146 |
|
1150 | 1147 |
|
1151 | | -def res_is_a(fgraph, var, op, maxclients=None): |
1152 | | - if maxclients is not None and var in fgraph.clients: |
1153 | | - retval = len(fgraph.get_clients(var)) <= maxclients |
1154 | | - else: |
1155 | | - retval = True |
1156 | | - |
1157 | | - return var.owner and var.owner.op == op and retval |
1158 | | - |
1159 | | - |
1160 | | -def _as_scalar(res, dtype=None): |
1161 | | - """Return ``None`` or a `TensorVariable` of float type""" |
1162 | | - if dtype is None: |
1163 | | - dtype = config.floatX |
1164 | | - if all(s == 1 for s in res.type.shape): |
1165 | | - while res.owner and isinstance(res.owner.op, DimShuffle): |
1166 | | - res = res.owner.inputs[0] |
1167 | | - # may still have some number of broadcastable (size-1) dimensions
1168 | | - if res.type.ndim > 0: |
1169 | | - rval = res.dimshuffle() |
1170 | | - else: |
1171 | | - rval = res |
1172 | | - if rval.type.dtype in integer_dtypes: |
1173 | | - # We check that the upcast of res and dtype won't change dtype. |
1174 | | - # If dtype is float64, we will cast int64 to float64. |
1175 | | - # This is valid when res is a scalar used as input to a dot22 |
1176 | | - # as the cast of the scalar can be done before or after the dot22 |
1177 | | - # and this will give the same result. |
1178 | | - if pytensor.scalar.upcast(res.dtype, dtype) == dtype: |
1179 | | - return ptb.cast(rval, dtype) |
1180 | | - else: |
1181 | | - return None |
1182 | | - |
1183 | | - return rval |
1184 | | - |
1185 | | - |
1186 | | -def _is_real_matrix(res): |
1187 | | - return ( |
1188 | | - res.type.dtype in ("float16", "float32", "float64") |
1189 | | - and res.type.ndim == 2 |
1190 | | - and res.type.shape[0] != 1 |
1191 | | - and res.type.shape[1] != 1 |
1192 | | - ) # cope with tuple vs. list |
1193 | | - |
1194 | | - |
1195 | | -def _is_real_vector(res): |
1196 | | - return ( |
1197 | | - res.type.dtype in ("float16", "float32", "float64") |
1198 | | - and res.type.ndim == 1 |
1199 | | - and res.type.shape[0] != 1 |
1200 | | - ) |
1201 | | - |
1202 | | - |
1203 | | -def _beta_L_plus_alpha_M(fgraph, beta, L, alpha, M, recurse_flip=True): |
1204 | | - # print 'BETA L + ALPHA M', beta, L, alpha, M, recurse_flip |
1205 | | - # EXPRESSION: (beta * L) + (alpha * M) |
1206 | | - |
1207 | | - # we've already checked the client counts, now just make the type check. |
1208 | | - # if res_is_a(M, _dot22, 1): |
1209 | | - if M.owner and M.owner.op == _dot22: |
1210 | | - Ml, Mr = M.owner.inputs |
1211 | | - rval = [gemm_no_inplace(L, alpha, Ml, Mr, beta)] |
1212 | | - return rval, M |
1213 | | - |
1214 | | - # it also might be the case that there is a dimshuffle between the + |
1215 | | - # and the dot22. local_dot_to_dot22 in particular will put in such things. |
1216 | | - if ( |
1217 | | - M.owner |
1218 | | - and isinstance(M.owner.op, DimShuffle) |
1219 | | - and M.owner.inputs[0].owner |
1220 | | - and isinstance(M.owner.inputs[0].owner.op, Dot22) |
1221 | | - ): |
1222 | | - MM = M.owner.inputs[0] |
1223 | | - if M.owner.op.new_order == (0,): |
1224 | | - # it is making a column MM into a vector |
1225 | | - MMl, MMr = MM.owner.inputs |
1226 | | - g = gemm_no_inplace(L.dimshuffle(0, "x"), alpha, MMl, MMr, beta) |
1227 | | - rval = [g.dimshuffle(0)] |
1228 | | - return rval, MM |
1229 | | - if M.owner.op.new_order == (1,): |
1230 | | - # it is making a row MM into a vector |
1231 | | - MMl, MMr = MM.owner.inputs |
1232 | | - g = gemm_no_inplace(L.dimshuffle("x", 0), alpha, MMl, MMr, beta) |
1233 | | - rval = [g.dimshuffle(1)] |
1234 | | - return rval, MM |
1235 | | - if len(M.owner.op.new_order) == 0: |
1236 | | - # it is making a 1x1 matrix MM into a scalar
1237 | | - MMl, MMr = MM.owner.inputs |
1238 | | - g = gemm_no_inplace(L.dimshuffle("x", "x"), alpha, MMl, MMr, beta) |
1239 | | - rval = [g.dimshuffle()] |
1240 | | - return rval, MM |
1241 | | - |
1242 | | - if recurse_flip: |
1243 | | - return _beta_L_plus_alpha_M(fgraph, alpha, M, beta, L, recurse_flip=False) |
1244 | | - else: |
1245 | | - return False, False |
1246 | | - |
1247 | | - |
1248 | | -def _gemm_canonicalize(fgraph, r, scale, rval, maxclients): |
1249 | | - # Tries to interpret node as a sum of scalars * (vectors or matrices) |
1250 | | - def scaled(thing): |
1251 | | - if scale == 1: |
1252 | | - return thing |
1253 | | - if scale == -1 and thing.type.dtype != "bool": |
1254 | | - return -thing |
1255 | | - else: |
1256 | | - return scale * thing |
1257 | | - |
1258 | | - if not isinstance(r.type, TensorType): |
1259 | | - return None |
1260 | | - |
1261 | | - if (r.type.ndim not in (1, 2)) or r.type.dtype not in ( |
1262 | | - "float16", |
1263 | | - "float32", |
1264 | | - "float64", |
1265 | | - "complex64", |
1266 | | - "complex128", |
1267 | | - ): |
1268 | | - rval.append(scaled(r)) |
1269 | | - return rval |
1270 | | - |
1271 | | - if maxclients and len(fgraph.clients[r]) > maxclients: |
1272 | | - rval.append((scale, r)) |
1273 | | - return rval |
1274 | | - |
1275 | | - if r.owner and r.owner.op == sub: |
1276 | | - _gemm_canonicalize(fgraph, r.owner.inputs[0], scale, rval, 1) |
1277 | | - _gemm_canonicalize(fgraph, r.owner.inputs[1], -scale, rval, 1) |
1278 | | - |
1279 | | - elif r.owner and r.owner.op == add: |
1280 | | - for i in r.owner.inputs: |
1281 | | - _gemm_canonicalize(fgraph, i, scale, rval, 1) |
1282 | | - |
1283 | | - elif r.owner and r.owner.op == neg: |
1284 | | - _gemm_canonicalize(fgraph, r.owner.inputs[0], -scale, rval, 1) |
1285 | | - |
1286 | | - elif r.owner and r.owner.op == mul: |
1287 | | - scalars = [] |
1288 | | - vectors = [] |
1289 | | - matrices = [] |
1290 | | - for i in r.owner.inputs: |
1291 | | - if all(s == 1 for s in i.type.shape): |
1292 | | - while i.owner and isinstance(i.owner.op, DimShuffle): |
1293 | | - i = i.owner.inputs[0] |
1294 | | - if i.type.ndim > 0: |
1295 | | - scalars.append(i.dimshuffle()) |
1296 | | - else: |
1297 | | - scalars.append(i) |
1298 | | - elif _is_real_vector(i): |
1299 | | - vectors.append(i) |
1300 | | - elif _is_real_matrix(i): |
1301 | | - matrices.append(i) |
1302 | | - else: |
1303 | | - # just put the original arguments as in the base case |
1304 | | - rval.append((scale, r)) |
1305 | | - return rval |
1306 | | - if len(matrices) == 1: |
1307 | | - assert len(vectors) == 0 |
1308 | | - m = matrices[0] |
1309 | | - if len(scalars) == 0: |
1310 | | - _gemm_canonicalize(fgraph, m, scale, rval, 1) |
1311 | | - elif len(scalars) == 1: |
1312 | | - _gemm_canonicalize(fgraph, m, scaled(scalars[0]), rval, 1) |
1313 | | - else: |
1314 | | - _gemm_canonicalize( |
1315 | | - fgraph, m, mul(scaled(scalars[0]), *scalars[1:]), rval, 1 |
1316 | | - ) |
1317 | | - elif len(vectors) == 1: |
1318 | | - assert len(matrices) == 0 |
1319 | | - v = vectors[0] |
1320 | | - if len(scalars) == 0: |
1321 | | - _gemm_canonicalize(fgraph, v, scale, rval, 1) |
1322 | | - elif len(scalars) == 1: |
1323 | | - _gemm_canonicalize(fgraph, v, scaled(scalars[0]), rval, 1) |
1324 | | - else: |
1325 | | - _gemm_canonicalize( |
1326 | | - fgraph, v, mul(scaled(scalars[0]), *scalars[1:]), rval, 1 |
1327 | | - ) |
1328 | | - else: # let's not open this up
1329 | | - rval.append((scale, r)) |
1330 | | - else: |
1331 | | - rval.append((scale, r)) |
1332 | | - return rval |
1333 | | - |
1334 | | - |
1335 | | -def _factor_canonicalized(lst): |
1336 | | - # remove duplicates from canonicalized list |
1337 | | - |
1338 | | - # we only delete from the right end of the list;
1339 | | - # once i has touched a list element, its position is permanent
1340 | | - lst = list(lst) |
1341 | | - # print 'FACTOR', lst |
1342 | | - # for t in lst: |
1343 | | - # if not isinstance(t, (list, tuple)): |
1344 | | - # t = (t,) |
1345 | | - # for e in t: |
1346 | | - # try: |
1347 | | - # pytensor.printing.debugprint(e) |
1348 | | - # except TypeError: |
1349 | | - # print e, type(e) |
1350 | | - i = 0 |
1351 | | - while i < len(lst) - 1: |
1352 | | - try: |
1353 | | - s_i, M_i = lst[i] |
1354 | | - except Exception: |
1355 | | - i += 1 |
1356 | | - continue |
1357 | | - |
1358 | | - j = i + 1 |
1359 | | - while j < len(lst): |
1360 | | - try: |
1361 | | - s_j, M_j = lst[j] |
1362 | | - except Exception: |
1363 | | - j += 1 |
1364 | | - continue |
1365 | | - |
1366 | | - if M_i is M_j: |
1367 | | - s_i = s_i + s_j |
1368 | | - lst[i] = (s_i, M_i) |
1369 | | - del lst[j] |
1370 | | - else: |
1371 | | - j += 1 |
1372 | | - i += 1 |
1373 | | - return lst |
1374 | | - |
1375 | | - |
1376 | | -def _gemm_from_factored_list(fgraph, lst): |
1377 | | - """ |
1378 | | - Returns None, or a list to replace node.outputs. |
1379 | | -
|
1380 | | - """ |
1381 | | - lst2 = [] |
1382 | | - # Remove the tuples that can't be cast correctly.
1383 | | - # This can happen when we try to cast a complex to a real.
1384 | | - for sM in lst: |
1385 | | - # Make every pair in list have matching dtypes |
1386 | | - # sM can be a tuple of 2 elements or a PyTensor variable.
1387 | | - if isinstance(sM, tuple): |
1388 | | - sm0, sm1 = sM |
1389 | | - sm0 = ptb.as_tensor_variable(sm0) |
1390 | | - if pytensor.scalar.upcast(sm0.dtype, sm1.dtype) == sm1.dtype: |
1391 | | - lst2.append((ptb.cast(sm0, sm1.dtype), sM[1])) |
1392 | | - |
1393 | | - lst = lst2 |
1394 | | - |
1395 | | - def item_to_var(t): |
1396 | | - try: |
1397 | | - s, M = t |
1398 | | - except Exception: |
1399 | | - return t |
1400 | | - if s == 1: |
1401 | | - return M |
1402 | | - if s == -1: |
1403 | | - return -M |
1404 | | - return s * M |
1405 | | - |
1406 | | - # Try every pair in lst, trying to turn it into a gemm operation
1407 | | - for i in range(len(lst) - 1): |
1408 | | - s_i, M_i = lst[i] |
1409 | | - |
1410 | | - for j in range(i + 1, len(lst)): |
1411 | | - s_j, M_j = lst[j] |
1412 | | - |
1413 | | - if not M_j.type.in_same_class(M_i.type): |
1414 | | - continue |
1415 | | - |
1416 | | - # print 'TRYING', (s_i, M_i, s_j, M_j) |
1417 | | - |
1418 | | - gemm_of_sM_list, old_dot22 = _beta_L_plus_alpha_M( |
1419 | | - fgraph, s_i, M_i, s_j, M_j |
1420 | | - ) |
1421 | | - # print 'GOT IT', gemm_of_sM_list |
1422 | | - if gemm_of_sM_list: |
1423 | | - assert len(gemm_of_sM_list) == 1 |
1424 | | - add_inputs = [ |
1425 | | - item_to_var(input) for k, input in enumerate(lst) if k not in (i, j) |
1426 | | - ] |
1427 | | - add_inputs.extend(gemm_of_sM_list) |
1428 | | - rval = [variadic_add(*add_inputs)] |
1429 | | - return rval, old_dot22 |
1430 | | - |
1431 | | - |
1432 | | -def _gemm_from_node2(fgraph, node): |
1433 | | - """ |
1434 | | -
|
1435 | | - TODO: Many expressions can be turned into a gemm in more than one
1436 | | - way; for example, dot(a, b) + c + d. This function should return all
1437 | | - of them, so that if one version of gemm causes a cycle in the graph,
1438 | | - another application of gemm can be tried.
1439 | | -
|
1440 | | - """ |
1441 | | - lst = [] |
1442 | | - t0 = time.perf_counter() |
1443 | | - _gemm_canonicalize(fgraph, node.outputs[0], 1.0, lst, 0) |
1444 | | - t1 = time.perf_counter() |
1445 | | - |
1446 | | - if len(lst) > 1: |
1447 | | - lst = _factor_canonicalized(lst) |
1448 | | - t2 = time.perf_counter() |
1449 | | - rval = _gemm_from_factored_list(fgraph, lst) |
1450 | | - t3 = time.perf_counter() |
1451 | | - |
1452 | | - # It can happen that _factor_canonicalized and |
1453 | | - # _gemm_from_factored_list return a node with an incorrect |
1454 | | - # type. This happens in particular when one of the scalar |
1455 | | - # factors forces the upcast of the whole expression. In that |
1456 | | - # case, we simply skip that candidate for Gemm. This was |
1457 | | - # discussed in |
1458 | | - # http://groups.google.com/group/theano-dev/browse_thread/thread/a3096c82856e3ad5, |
1459 | | - # but never made it into a trac ticket. |
1460 | | - |
1461 | | - if rval and rval[0][0].type.in_same_class(node.outputs[0].type): |
1462 | | - return rval, t1 - t0, t2 - t1, t3 - t2 |
1463 | | - |
1464 | | - return None, t1 - t0, 0, 0 |
1465 | | - |
1466 | | - |
1467 | 1148 | class Dot22(GemmRelated): |
1468 | 1149 | """Compute a matrix-matrix product. |
1469 | 1150 |
|
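Note on the removal above: `_beta_L_plus_alpha_M` and `_gemm_from_node2` implemented the classic GEMM fusion, collapsing `beta * L + alpha * dot(A, B)` into a single BLAS call. A minimal NumPy/SciPy sketch of that identity (illustrative only, not PyTensor API):

    import numpy as np
    from scipy.linalg import blas

    rng = np.random.default_rng(0)
    A = rng.standard_normal((3, 4))
    B = rng.standard_normal((4, 5))
    L = rng.standard_normal((3, 5))
    alpha, beta = 0.5, 2.0

    fused = blas.dgemm(alpha, A, B, beta=beta, c=L)  # one GEMM call
    naive = beta * L + alpha * (A @ B)               # two temporaries
    assert np.allclose(fused, naive)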
|
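The removed `_gemm_canonicalize` and `_factor_canonicalized` first flattened an add/sub/neg tree into `(scale, term)` pairs, then merged pairs sharing the same term. A toy sketch of that two-step idea in plain Python (hypothetical tuple-based expressions, not PyTensor graph nodes):

    def canonicalize(expr, scale, out):
        # expr is ("add", ...), ("sub", a, b), ("neg", a), or ("leaf", name)
        op, *args = expr
        if op == "add":
            for a in args:
                canonicalize(a, scale, out)
        elif op == "sub":
            canonicalize(args[0], scale, out)
            canonicalize(args[1], -scale, out)
        elif op == "neg":
            canonicalize(args[0], -scale, out)
        else:  # leaf
            out.append((scale, args[0]))
        return out

    def factor(pairs):
        # merge duplicate terms by summing their scales
        merged = {}
        for s, t in pairs:
            merged[t] = merged.get(t, 0) + s
        return [(s, t) for t, s in merged.items() if s != 0]

    # a - b + (-a) flattens to [(1, 'a'), (-1, 'b'), (-1, 'a')],
    # then factors to [(-1, 'b')]
    expr = ("add", ("sub", ("leaf", "a"), ("leaf", "b")), ("neg", ("leaf", "a")))
    print(factor(canonicalize(expr, 1, [])))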
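The comment in the removed `_as_scalar` argues that casting an integer scalar before or after the `dot22` gives the same result, which is why the upcast there is safe. A quick NumPy illustration of that property:

    import numpy as np

    A = np.ones((2, 3))
    B = np.ones((3, 4))
    s = np.int64(3)  # integer scalar that upcasts to float64

    cast_then_dot = np.float64(s) * (A @ B)  # cast the scalar, then scale
    dot_then_cast = (s * A) @ B              # scale first; the mul upcasts
    assert np.allclose(cast_then_dot, dot_then_cast)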