Skip to content

Commit 0a66671

Browse files
committed
add placeholder class for operator flop counts that aren't specified
1 parent 09a71f2 commit 0a66671

File tree

2 files changed

+44
-8
lines changed

2 files changed

+44
-8
lines changed

pytato/analysis/__init__.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -806,13 +806,16 @@ def _get_own_flop_count(self, expr: Array) -> int:
806806
return 0
807807
nflops = self.scalar_flop_counter(to_index_lambda(expr).expr)
808808
if not isinstance(nflops, int):
809-
from pytato.scalar_expr import InputGatherer as ScalarInputGatherer
810-
var_names: set[str] = set(ScalarInputGatherer()(nflops))
811-
var_names.discard("nflops")
812-
if var_names:
813-
raise UndefinedOpFlopCountError(next(iter(var_names))) from None
809+
# Restricting to numerical result here because the flop counters that use
810+
# this mapper subsequently multiply the result by things that are
811+
# potentially arrays (e.g., shape components), and arrays and scalar
812+
# expressions are not interoperable
813+
from pytato.scalar_expr import OpFlops, OpFlopsCollector
814+
op_flops: frozenset[OpFlops] = OpFlopsCollector()(nflops)
815+
if op_flops:
816+
raise UndefinedOpFlopCountError(next(iter(op_flops)).op)
814817
else:
815-
raise AssertionError from None
818+
raise AssertionError
816819
return nflops
817820

818821
@override

pytato/scalar_expr.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444

4545
import re
4646
from collections.abc import Iterable, Mapping, Set as AbstractSet
47+
from functools import reduce
4748
from typing import (
4849
TYPE_CHECKING,
4950
Any,
@@ -266,8 +267,7 @@ def _get_op_nflops(self, name: str) -> ArithmeticExpression:
266267
try:
267268
return self.op_name_to_num_flops[name]
268269
except KeyError:
269-
from pymbolic import var
270-
result = var("nflops")(var(name))
270+
result = OpFlops(name)
271271
self.op_name_to_num_flops[name] = result
272272
return result
273273

@@ -480,9 +480,42 @@ class TypeCast(ExpressionBase):
480480
dtype: np.dtype[Any]
481481
inner_expr: ScalarExpression
482482

483+
484+
@expr_dataclass()
485+
class OpFlops(prim.AlgebraicLeaf):
486+
"""
487+
Placeholder flop count for an operator.
488+
489+
.. autoattribute:: op
490+
"""
491+
op: str
492+
483493
# }}}
484494

485495

496+
class OpFlopsCollector(CombineMapper[frozenset[OpFlops], []]):
497+
"""
498+
Constructs a :class:`frozenset` containing all instances of
499+
:class:`pytato.scalar_expr.OpFlops` found in a scalar expression.
500+
"""
501+
@override
502+
def combine(
503+
self, values: Iterable[frozenset[OpFlops]]) -> frozenset[OpFlops]:
504+
return reduce(
505+
lambda x, y: x.union(y),
506+
values,
507+
cast("frozenset[OpFlops]", frozenset()))
508+
509+
@override
510+
def map_algebraic_leaf(
511+
self, expr: prim.AlgebraicLeaf) -> frozenset[OpFlops]:
512+
return frozenset([expr]) if isinstance(expr, OpFlops) else frozenset()
513+
514+
@override
515+
def map_constant(self, expr: object) -> frozenset[OpFlops]:
516+
return frozenset()
517+
518+
486519
class InductionVariableCollector(CombineMapper[AbstractSet[str], []]):
487520
def combine(self, values: Iterable[AbstractSet[str]]) -> frozenset[str]:
488521
from functools import reduce

0 commit comments

Comments
 (0)