Commit 35e87e0

Remove duplicated BLAS rewriting code
Accidentally introduced in c655b02. Also move tests to the rewriting test file.
1 parent a0fe30d · commit 35e87e0

3 files changed: +165 −466 lines changed


pytensor/tensor/blas.py

Lines changed: 1 addition & 320 deletions
@@ -79,7 +79,6 @@
 import logging
 import os
 import shlex
-import time
 from pathlib import Path
 
 import numpy as np
@@ -103,10 +102,8 @@
 from pytensor.tensor import basic as ptb
 from pytensor.tensor.basic import expand_dims
 from pytensor.tensor.blas_headers import blas_header_text, blas_header_version
-from pytensor.tensor.elemwise import DimShuffle
-from pytensor.tensor.math import add, mul, neg, sub, variadic_add
 from pytensor.tensor.shape import shape_padright, specify_broadcastable
-from pytensor.tensor.type import DenseTensorType, TensorType, integer_dtypes, tensor
+from pytensor.tensor.type import DenseTensorType, tensor
 
 
 _logger = logging.getLogger("pytensor.tensor.blas")
@@ -1148,322 +1145,6 @@ def c_code_cache_version(self):
 pprint.assign(gemm_no_inplace, FunctionPrinter(["gemm_no_inplace"]))
 
 
-def res_is_a(fgraph, var, op, maxclients=None):
-    if maxclients is not None and var in fgraph.clients:
-        retval = len(fgraph.get_clients(var)) <= maxclients
-    else:
-        retval = True
-
-    return var.owner and var.owner.op == op and retval
-
-
-def _as_scalar(res, dtype=None):
-    """Return ``None`` or a `TensorVariable` of float type"""
-    if dtype is None:
-        dtype = config.floatX
-    if all(s == 1 for s in res.type.shape):
-        while res.owner and isinstance(res.owner.op, DimShuffle):
-            res = res.owner.inputs[0]
-        # may still have some number of True's
-        if res.type.ndim > 0:
-            rval = res.dimshuffle()
-        else:
-            rval = res
-        if rval.type.dtype in integer_dtypes:
-            # We check that the upcast of res and dtype won't change dtype.
-            # If dtype is float64, we will cast int64 to float64.
-            # This is valid when res is a scalar used as input to a dot22
-            # as the cast of the scalar can be done before or after the dot22
-            # and this will give the same result.
-            if pytensor.scalar.upcast(res.dtype, dtype) == dtype:
-                return ptb.cast(rval, dtype)
-            else:
-                return None
-
-        return rval
-
-
-def _is_real_matrix(res):
-    return (
-        res.type.dtype in ("float16", "float32", "float64")
-        and res.type.ndim == 2
-        and res.type.shape[0] != 1
-        and res.type.shape[1] != 1
-    )  # cope with tuple vs. list
-
-
-def _is_real_vector(res):
-    return (
-        res.type.dtype in ("float16", "float32", "float64")
-        and res.type.ndim == 1
-        and res.type.shape[0] != 1
-    )
-
-
-def _beta_L_plus_alpha_M(fgraph, beta, L, alpha, M, recurse_flip=True):
-    # print 'BETA L + ALPHA M', beta, L, alpha, M, recurse_flip
-    # EXPRESSION: (beta * L) + (alpha * M)
-
-    # we've already checked the client counts, now just make the type check.
-    # if res_is_a(M, _dot22, 1):
-    if M.owner and M.owner.op == _dot22:
-        Ml, Mr = M.owner.inputs
-        rval = [gemm_no_inplace(L, alpha, Ml, Mr, beta)]
-        return rval, M
-
-    # it also might be the case that there is a dimshuffle between the +
-    # and the dot22. local_dot_to_dot22 in particular will put in such things.
-    if (
-        M.owner
-        and isinstance(M.owner.op, DimShuffle)
-        and M.owner.inputs[0].owner
-        and isinstance(M.owner.inputs[0].owner.op, Dot22)
-    ):
-        MM = M.owner.inputs[0]
-        if M.owner.op.new_order == (0,):
-            # it is making a column MM into a vector
-            MMl, MMr = MM.owner.inputs
-            g = gemm_no_inplace(L.dimshuffle(0, "x"), alpha, MMl, MMr, beta)
-            rval = [g.dimshuffle(0)]
-            return rval, MM
-        if M.owner.op.new_order == (1,):
-            # it is making a row MM into a vector
-            MMl, MMr = MM.owner.inputs
-            g = gemm_no_inplace(L.dimshuffle("x", 0), alpha, MMl, MMr, beta)
-            rval = [g.dimshuffle(1)]
-            return rval, MM
-        if len(M.owner.op.new_order) == 0:
-            # it is making a row MM into a vector
-            MMl, MMr = MM.owner.inputs
-            g = gemm_no_inplace(L.dimshuffle("x", "x"), alpha, MMl, MMr, beta)
-            rval = [g.dimshuffle()]
-            return rval, MM
-
-    if recurse_flip:
-        return _beta_L_plus_alpha_M(fgraph, alpha, M, beta, L, recurse_flip=False)
-    else:
-        return False, False
-
-
-def _gemm_canonicalize(fgraph, r, scale, rval, maxclients):
-    # Tries to interpret node as a sum of scalars * (vectors or matrices)
-    def scaled(thing):
-        if scale == 1:
-            return thing
-        if scale == -1 and thing.type.dtype != "bool":
-            return -thing
-        else:
-            return scale * thing
-
-    if not isinstance(r.type, TensorType):
-        return None
-
-    if (r.type.ndim not in (1, 2)) or r.type.dtype not in (
-        "float16",
-        "float32",
-        "float64",
-        "complex64",
-        "complex128",
-    ):
-        rval.append(scaled(r))
-        return rval
-
-    if maxclients and len(fgraph.clients[r]) > maxclients:
-        rval.append((scale, r))
-        return rval
-
-    if r.owner and r.owner.op == sub:
-        _gemm_canonicalize(fgraph, r.owner.inputs[0], scale, rval, 1)
-        _gemm_canonicalize(fgraph, r.owner.inputs[1], -scale, rval, 1)
-
-    elif r.owner and r.owner.op == add:
-        for i in r.owner.inputs:
-            _gemm_canonicalize(fgraph, i, scale, rval, 1)
-
-    elif r.owner and r.owner.op == neg:
-        _gemm_canonicalize(fgraph, r.owner.inputs[0], -scale, rval, 1)
-
-    elif r.owner and r.owner.op == mul:
-        scalars = []
-        vectors = []
-        matrices = []
-        for i in r.owner.inputs:
-            if all(s == 1 for s in i.type.shape):
-                while i.owner and isinstance(i.owner.op, DimShuffle):
-                    i = i.owner.inputs[0]
-                if i.type.ndim > 0:
-                    scalars.append(i.dimshuffle())
-                else:
-                    scalars.append(i)
-            elif _is_real_vector(i):
-                vectors.append(i)
-            elif _is_real_matrix(i):
-                matrices.append(i)
-            else:
-                # just put the original arguments as in the base case
-                rval.append((scale, r))
-                return rval
-        if len(matrices) == 1:
-            assert len(vectors) == 0
-            m = matrices[0]
-            if len(scalars) == 0:
-                _gemm_canonicalize(fgraph, m, scale, rval, 1)
-            elif len(scalars) == 1:
-                _gemm_canonicalize(fgraph, m, scaled(scalars[0]), rval, 1)
-            else:
-                _gemm_canonicalize(
-                    fgraph, m, mul(scaled(scalars[0]), *scalars[1:]), rval, 1
-                )
-        elif len(vectors) == 1:
-            assert len(matrices) == 0
-            v = vectors[0]
-            if len(scalars) == 0:
-                _gemm_canonicalize(fgraph, v, scale, rval, 1)
-            elif len(scalars) == 1:
-                _gemm_canonicalize(fgraph, v, scaled(scalars[0]), rval, 1)
-            else:
-                _gemm_canonicalize(
-                    fgraph, v, mul(scaled(scalars[0]), *scalars[1:]), rval, 1
-                )
-        else:  # lets not open this up
-            rval.append((scale, r))
-    else:
-        rval.append((scale, r))
-    return rval
-
-
-def _factor_canonicalized(lst):
-    # remove duplicates from canonicalized list
-
-    # we only delete out of the right end of the list,
-    # once i has touched a list element, it is permantent
-    lst = list(lst)
-    # print 'FACTOR', lst
-    # for t in lst:
-    #    if not isinstance(t, (list, tuple)):
-    #        t = (t,)
-    #    for e in t:
-    #        try:
-    #            pytensor.printing.debugprint(e)
-    #        except TypeError:
-    #            print e, type(e)
-    i = 0
-    while i < len(lst) - 1:
-        try:
-            s_i, M_i = lst[i]
-        except Exception:
-            i += 1
-            continue
-
-        j = i + 1
-        while j < len(lst):
-            try:
-                s_j, M_j = lst[j]
-            except Exception:
-                j += 1
-                continue
-
-            if M_i is M_j:
-                s_i = s_i + s_j
-                lst[i] = (s_i, M_i)
-                del lst[j]
-            else:
-                j += 1
-        i += 1
-    return lst
-
-
-def _gemm_from_factored_list(fgraph, lst):
-    """
-    Returns None, or a list to replace node.outputs.
-
-    """
-    lst2 = []
-    # Remove the tuple that can't be cast correctly.
-    # This can happen when we try to cast a complex to a real
-    for sM in lst:
-        # Make every pair in list have matching dtypes
-        # sM can be a tuple of 2 elements or an PyTensor variable.
-        if isinstance(sM, tuple):
-            sm0, sm1 = sM
-            sm0 = ptb.as_tensor_variable(sm0)
-            if pytensor.scalar.upcast(sm0.dtype, sm1.dtype) == sm1.dtype:
-                lst2.append((ptb.cast(sm0, sm1.dtype), sM[1]))
-
-    lst = lst2
-
-    def item_to_var(t):
-        try:
-            s, M = t
-        except Exception:
-            return t
-        if s == 1:
-            return M
-        if s == -1:
-            return -M
-        return s * M
-
-    # Try every pair in the sM_list, trying to turn it into a gemm operation
-    for i in range(len(lst) - 1):
-        s_i, M_i = lst[i]
-
-        for j in range(i + 1, len(lst)):
-            s_j, M_j = lst[j]
-
-            if not M_j.type.in_same_class(M_i.type):
-                continue
-
-            # print 'TRYING', (s_i, M_i, s_j, M_j)
-
-            gemm_of_sM_list, old_dot22 = _beta_L_plus_alpha_M(
-                fgraph, s_i, M_i, s_j, M_j
-            )
-            # print 'GOT IT', gemm_of_sM_list
-            if gemm_of_sM_list:
-                assert len(gemm_of_sM_list) == 1
-                add_inputs = [
-                    item_to_var(input) for k, input in enumerate(lst) if k not in (i, j)
-                ]
-                add_inputs.extend(gemm_of_sM_list)
-                rval = [variadic_add(*add_inputs)]
-                return rval, old_dot22
-
-
-def _gemm_from_node2(fgraph, node):
-    """
-
-    TODO: In many expressions, there are many ways to turn it into a
-    gemm. For example dot(a,b) + c + d. This function should return all
-    of them, so that if one version of gemm causes a cycle in the graph, then
-    another application of gemm can be tried.
-
-    """
-    lst = []
-    t0 = time.perf_counter()
-    _gemm_canonicalize(fgraph, node.outputs[0], 1.0, lst, 0)
-    t1 = time.perf_counter()
-
-    if len(lst) > 1:
-        lst = _factor_canonicalized(lst)
-        t2 = time.perf_counter()
-        rval = _gemm_from_factored_list(fgraph, lst)
-        t3 = time.perf_counter()
-
-        # It can happen that _factor_canonicalized and
-        # _gemm_from_factored_list return a node with an incorrect
-        # type. This happens in particular when one of the scalar
-        # factors forces the upcast of the whole expression. In that
-        # case, we simply skip that candidate for Gemm. This was
-        # discussed in
-        # http://groups.google.com/group/theano-dev/browse_thread/thread/a3096c82856e3ad5,
-        # but never made it into a trac ticket.
-
-        if rval and rval[0][0].type.in_same_class(node.outputs[0].type):
-            return rval, t1 - t0, t2 - t1, t3 - t2
-
-    return None, t1 - t0, 0, 0
-
-
 class Dot22(GemmRelated):
     """Compute a matrix-matrix product.
 

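For context: the deleted helpers (res_is_a, _as_scalar, _gemm_canonicalize, _factor_canonicalized, _gemm_from_factored_list, _gemm_from_node2) implement the pass that rewrites expressions of the form beta * L + alpha * dot(A, B) into a single Gemm node; per the commit message, the surviving copy lives with the rewriting code and its tests now sit in the rewriting test file. Below is a minimal sketch (not part of this commit) of how one might confirm the rewrite still fires after the duplicate is gone, assuming a standard PyTensor install with default optimizations and float64 as floatX; the variable names and the expected debugprint output are illustrative.

# Sketch (illustrative, not from this commit): check that the gemm
# rewrite still applies once the duplicated copy in blas.py is deleted.
import numpy as np

import pytensor
import pytensor.tensor as pt

A = pt.matrix("A")
B = pt.matrix("B")
C = pt.matrix("C")

# The pattern the rewrite targets: beta * C + alpha * dot(A, B).
out = 0.5 * C + 2.0 * pt.dot(A, B)

# With default optimizations the compiled graph is expected to contain
# a Gemm node rather than separate Dot/Mul/Add nodes.
f = pytensor.function([A, B, C], out)
pytensor.printing.debugprint(f)

# Numerical sanity check against NumPy (assumes floatX == "float64").
rng = np.random.default_rng(0)
a, b, c = (rng.standard_normal((4, 4)) for _ in range(3))
np.testing.assert_allclose(f(a, b, c), 0.5 * c + 2.0 * (a @ b))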
0 commit comments
