Skip to content

Commit 9504a47

Browse files
allanrenucciGoogle-ML-Automation
authored andcommitted
[Mosaic GPU][NFC] Rename EquationSystem to ConstraintSystem.
I decided to alias the constraints module as `cs` because `constraints` is widely used as a variable name and it resulted in a lot of conflicts. PiperOrigin-RevId: 836308644
1 parent f1601fb commit 9504a47

File tree

6 files changed

+711
-678
lines changed

6 files changed

+711
-678
lines changed

jax/experimental/mosaic/gpu/equations.py renamed to jax/experimental/mosaic/gpu/constraints.py

Lines changed: 47 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
"""Defines expressions and equations over layouts."""
15+
"""Defines expressions and constraints over layouts."""
1616

1717
# mypy has been causing more problems than it solves here. Disable it for these
1818
# files. We have pytype checks anyway.
@@ -379,7 +379,7 @@ class Relayout:
379379
do not ever plan to support it.
380380
381381
Modeling this constraint this way is helpful, in order to allow pruning
382-
inefficient solutions when attempting to solve an equation system.
382+
inefficient solutions when attempting to solve a constraint system.
383383
"""
384384

385385
source: Expression
@@ -615,16 +615,14 @@ def reduce_constraint(
615615

616616

617617
@dataclasses.dataclass
618-
class EquationSystem:
619-
"""An equation system contains a set of equations and assignments.
618+
class ConstraintSystem:
619+
"""A constraint system contains a set of constraints and assignments.
620620
621621
Assignments assign constant values to variables in the system (bound
622-
variables). Equations describe relationships between variables, and can be
623-
used to determine assignments for unknown (free) variables.
624-
625-
Constraints are used to check predicates that must hold for the assignments to
626-
be valid.
622+
variables). Constraints describe relationships between variables that must be
623+
upheld, and can be used to determine assignments for unknown (free) variables.
627624
"""
625+
628626
assignments: dict[Variable, Constant] = dataclasses.field(
629627
default_factory=dict
630628
)
@@ -677,17 +675,19 @@ def extract_variables(expr: Expression) -> None:
677675
assert_never(never)
678676
return free_variables
679677

680-
def __and__(self, other: EquationSystem) -> EquationSystem | Unsatisfiable:
678+
def __and__(
679+
self, other: ConstraintSystem
680+
) -> ConstraintSystem | Unsatisfiable:
681681
for variable, assignment in self.assignments.items():
682682
if variable in other.assignments and assignment != other.assignments[variable]:
683683
return Unsatisfiable()
684-
return EquationSystem(
684+
return ConstraintSystem(
685685
assignments=self.assignments | other.assignments,
686686
constraints=[*self.constraints, *other.constraints],
687687
)
688688

689689
def __str__(self):
690-
r = "EquationSystem\n"
690+
r = "ConstraintSystem\n"
691691
r += " assignments:\n"
692692
for assignment, constant in self.assignments.items():
693693
r += f" {assignment}{constant}\n"
@@ -696,9 +696,11 @@ def __str__(self):
696696
r += f" {constraint}\n"
697697
return r
698698

699+
699700
@final
700701
class Unsatisfiable:
701-
def __and__(self, other: EquationSystem | Unsatisfiable) -> Unsatisfiable:
702+
703+
def __and__(self, other: ConstraintSystem | Unsatisfiable) -> Unsatisfiable:
702704
return self
703705

704706

@@ -747,8 +749,8 @@ def is_constant_splat(e) -> bool:
747749

748750

749751
def saturate_distinct_from_splat(
750-
equation_system: EquationSystem,
751-
) -> EquationSystem | Unsatisfiable:
752+
constraint_system: ConstraintSystem,
753+
) -> ConstraintSystem | Unsatisfiable:
752754
"""Adds transitive NotOfType constraints for all non-splat variables.
753755
754756
Given `n` variables `l0`, ... `l{n-1}`, and a set of relayouts
@@ -759,13 +761,13 @@ def saturate_distinct_from_splat(
759761
This helps us quickly conclude that a system is unsatisfiable in cases where
760762
a non-splat variable is transitively relaid out into a splat layout.
761763
"""
762-
non_splat = non_splat_variables(equation_system.constraints)
764+
non_splat = non_splat_variables(constraint_system.constraints)
763765
new_constraints: list[Constraint] = []
764766
new_non_splat_found = len(non_splat) > 0
765767

766768
while new_non_splat_found:
767769
new_non_splat_found = False
768-
for constraint in equation_system.constraints:
770+
for constraint in constraint_system.constraints:
769771
match constraint:
770772
case Relayout(source=source, target=target):
771773
if (
@@ -778,19 +780,19 @@ def saturate_distinct_from_splat(
778780
new_constraints.append(NotOfType(target, fa.WGSplatFragLayout))
779781
case _:
780782
pass
781-
return equation_system & EquationSystem(constraints=new_constraints)
783+
return constraint_system & ConstraintSystem(constraints=new_constraints)
782784

783785

784786
def compute_transitively_equal_vars(
785-
system: EquationSystem,
787+
system: ConstraintSystem,
786788
) -> dict[Variable, list[Variable]]:
787-
"""Computes all transitively equal variables in an equation system.
789+
"""Computes all transitively equal variables in a constraint system.
788790
789-
The output dictionary maps each variable that appears in equations in the
790-
equation system to all the variables it is transitively equal to.
791+
The output dictionary maps each variable that appears in constraints in the
792+
constraint system to all the variables it is transitively equal to.
791793
"""
792794
# The equality relations between variables form a graph where variables are
793-
# nodes and an equation `v1 == v2` forms an edge. All variables in a
795+
# nodes and a constraint `v1 == v2` forms an edge. All variables in a
794796
# connected component are transitively equal. We use a Union-Find data
795797
# structure with path compression to efficiently find these connected
796798
# components (i.e., equivalence classes).
@@ -833,8 +835,8 @@ def union(v1: Variable, v2: Variable):
833835

834836

835837
def saturate_divides_constraints_for_equal_vars(
836-
system: EquationSystem,
837-
) -> EquationSystem:
838+
system: ConstraintSystem,
839+
) -> ConstraintSystem:
838840
"""Saturates Divides constraints between all transitively equal vars.
839841
"""
840842
equal_vars = compute_transitively_equal_vars(system)
@@ -882,17 +884,17 @@ def merge_divides_constraints(constraints: Sequence[Constraint]) -> list[Constra
882884

883885

884886
def _reduce_system_once(
885-
equation_system: EquationSystem,
886-
) -> EquationSystem | Unsatisfiable | None:
887-
"""Performs one reduction step over each equation in an equation system.
887+
constraint_system: ConstraintSystem,
888+
) -> ConstraintSystem | Unsatisfiable | None:
889+
"""Performs one reduction step over each constraint in a constraint system.
888890
889891
Returns:
890-
- Unsatisfiable(): if the equation system is unsatisfiable.
891-
- A new equation system if any equation was reduced.
892-
- None: if the equation system is not known unsatisfiable, but hasn't been
892+
- Unsatisfiable(): if the constraint system is unsatisfiable.
893+
- A new constraint system if any constraint was reduced.
894+
- None: if the constraint system is not known unsatisfiable, but hasn't been
893895
reduced.
894896
"""
895-
assignments = equation_system.assignments
897+
assignments = constraint_system.assignments
896898
constraints: list[Constraint] = []
897899
changed = False
898900

@@ -902,7 +904,7 @@ def try_assign(var: Variable, cst: Constant) -> bool:
902904
assignments[var] = cst
903905
return True
904906

905-
for constraint in equation_system.constraints:
907+
for constraint in constraint_system.constraints:
906908
match reduce_constraint(constraint, assignments):
907909
case Unsatisfiable():
908910
return Unsatisfiable()
@@ -930,29 +932,31 @@ def try_assign(var: Variable, cst: Constant) -> bool:
930932
return Unsatisfiable()
931933

932934
if changed:
933-
return EquationSystem(
934-
assignments=assignments | equation_system.assignments,
935+
return ConstraintSystem(
936+
assignments=assignments | constraint_system.assignments,
935937
constraints=constraints,
936938
)
937939
return None
938940

939941

940-
def reduce(equation_system: EquationSystem) -> EquationSystem | Unsatisfiable:
941-
"""Reduces an equation system until it can no longer be reduced.
942+
def reduce(
943+
constraint_system: ConstraintSystem,
944+
) -> ConstraintSystem | Unsatisfiable:
945+
"""Reduces a constraint system until it can no longer be reduced.
942946
943947
Returns:
944-
- Unsatisfiable(): if the equation system is unsatisfiable.
945-
- The maximally reduced equation system otherwise.
948+
- Unsatisfiable(): if the constraint system is unsatisfiable.
949+
- The maximally reduced constraint system otherwise.
946950
"""
947951
while True:
948-
match _reduce_system_once(equation_system):
952+
match _reduce_system_once(constraint_system):
949953
case None:
950954
break
951955
case Unsatisfiable():
952956
return Unsatisfiable()
953-
case EquationSystem() as new_system:
954-
equation_system = new_system
957+
case ConstraintSystem() as new_system:
958+
constraint_system = new_system
955959
case _ as never:
956960
assert_never(never)
957961

958-
return equation_system
962+
return constraint_system

0 commit comments

Comments
 (0)