|
44 | 44 |
|
45 | 45 | import re |
46 | 46 | from collections.abc import Iterable, Mapping, Set as AbstractSet |
| 47 | +from functools import reduce |
47 | 48 | from typing import ( |
48 | 49 | TYPE_CHECKING, |
49 | 50 | Any, |
@@ -266,8 +267,7 @@ def _get_op_nflops(self, name: str) -> ArithmeticExpression: |
266 | 267 | try: |
267 | 268 | return self.op_name_to_num_flops[name] |
268 | 269 | except KeyError: |
269 | | - from pymbolic import var |
270 | | - result = var("nflops")(var(name)) |
| 270 | + result = OpFlops(name) |
271 | 271 | self.op_name_to_num_flops[name] = result |
272 | 272 | return result |
273 | 273 |
|
@@ -480,9 +480,42 @@ class TypeCast(ExpressionBase): |
480 | 480 | dtype: np.dtype[Any] |
481 | 481 | inner_expr: ScalarExpression |
482 | 482 |
|
| 483 | + |
| 484 | +@expr_dataclass() |
| 485 | +class OpFlops(prim.AlgebraicLeaf): |
| 486 | + """ |
| 487 | + Placeholder flop count for an operator. |
| 488 | +
|
| 489 | + .. autoattribute:: op |
| 490 | + """ |
| 491 | + op: str |
| 492 | + |
483 | 493 | # }}} |
484 | 494 |
|
485 | 495 |
|
| 496 | +class OpFlopsCollector(CombineMapper[frozenset[OpFlops], []]): |
| 497 | + """ |
| 498 | + Constructs a :class:`frozenset` containing all instances of |
| 499 | + :class:`pytato.scalar_expr.OpFlops` found in a scalar expression. |
| 500 | + """ |
| 501 | + @override |
| 502 | + def combine( |
| 503 | + self, values: Iterable[frozenset[OpFlops]]) -> frozenset[OpFlops]: |
| 504 | + return reduce( |
| 505 | + lambda x, y: x.union(y), |
| 506 | + values, |
| 507 | + cast("frozenset[OpFlops]", frozenset())) |
| 508 | + |
| 509 | + @override |
| 510 | + def map_algebraic_leaf( |
| 511 | + self, expr: prim.AlgebraicLeaf) -> frozenset[OpFlops]: |
| 512 | + return frozenset([expr]) if isinstance(expr, OpFlops) else frozenset() |
| 513 | + |
| 514 | + @override |
| 515 | + def map_constant(self, expr: object) -> frozenset[OpFlops]: |
| 516 | + return frozenset() |
| 517 | + |
| 518 | + |
486 | 519 | class InductionVariableCollector(CombineMapper[AbstractSet[str], []]): |
487 | 520 | def combine(self, values: Iterable[AbstractSet[str]]) -> frozenset[str]: |
488 | 521 | from functools import reduce |
|
0 commit comments