Skip to content

Commit 6b83414

Browse files
authored
Dedupe rewrite constraints without sorting (#5864)
Dedupe rewrite constraints by consuming them by their LHS from the map of rewrite values, and dropping any LHS that we see more than once. This essentially uses the map to track which LHS we have seen in place of sorting the rewrite constraints by the LHS.
1 parent 64c31a6 commit 6b83414

File tree

1 file changed

+69
-151
lines changed

1 file changed

+69
-151
lines changed

toolchain/check/facet_type.cpp

Lines changed: 69 additions & 151 deletions
Original file line numberDiff line numberDiff line change
@@ -263,18 +263,6 @@ struct FacetTypeConstraintValue {
263263
SemIR::ElementIndex access_index;
264264
SemIR::SpecificInterfaceId specific_interface_id;
265265

266-
friend auto operator<=>(const FacetTypeConstraintValue& lhs,
267-
const FacetTypeConstraintValue& rhs)
268-
-> std::weak_ordering {
269-
if (lhs.entity_name_id != rhs.entity_name_id) {
270-
return lhs.entity_name_id.index <=> rhs.entity_name_id.index;
271-
}
272-
if (lhs.access_index != rhs.access_index) {
273-
return lhs.access_index.index <=> rhs.access_index.index;
274-
}
275-
return lhs.specific_interface_id.index <=> rhs.specific_interface_id.index;
276-
}
277-
278266
friend auto operator==(const FacetTypeConstraintValue& lhs,
279267
const FacetTypeConstraintValue& rhs) -> bool = default;
280268
};
@@ -298,72 +286,30 @@ static auto GetFacetTypeConstraintValue(Context& context,
298286
return std::nullopt;
299287
}
300288

301-
// Returns an ordering between two values in a rewrite constraint. Two
289+
// Returns true if two values in a rewrite constraint are equivalent. Two
302290
// `ImplWitnessAccess` instructions that refer to the same associated constant
303-
// through the same facet value are treated as equivalent. Otherwise, the
304-
// ordering is somewhat arbitrary with `ImplWitnessAccess` instructions coming
305-
// first.
291+
// through the same facet value are treated as equivalent.
306292
static auto CompareFacetTypeConstraintValues(Context& context,
307293
SemIR::InstId lhs_id,
308-
SemIR::InstId rhs_id)
309-
-> std::weak_ordering {
294+
SemIR::InstId rhs_id) -> bool {
310295
if (lhs_id == rhs_id) {
311-
return std::weak_ordering::equivalent;
296+
return true;
312297
}
313298

314299
auto lhs_access = context.insts().TryGetAs<SemIR::ImplWitnessAccess>(lhs_id);
315300
auto rhs_access = context.insts().TryGetAs<SemIR::ImplWitnessAccess>(rhs_id);
316301
if (lhs_access && rhs_access) {
317302
auto lhs_access_value = GetFacetTypeConstraintValue(context, *lhs_access);
318303
auto rhs_access_value = GetFacetTypeConstraintValue(context, *rhs_access);
319-
if (lhs_access_value && rhs_access_value) {
320-
return *lhs_access_value <=> *rhs_access_value;
321-
}
322-
323304
// We do *not* want to get the evaluated result of `ImplWitnessAccess` here,
324305
// we want to keep them as a reference to an associated constant for the
325306
// resolution phase.
326-
return lhs_id.index <=> rhs_id.index;
307+
return lhs_access_value && rhs_access_value &&
308+
*lhs_access_value == *rhs_access_value;
327309
}
328310

329-
// ImplWitnessAccess sorts before other instructions.
330-
if (lhs_access) {
331-
return std::weak_ordering::less;
332-
}
333-
if (rhs_access) {
334-
return std::weak_ordering::greater;
335-
}
336-
337-
return context.constant_values().GetConstantInstId(lhs_id).index <=>
338-
context.constant_values().GetConstantInstId(rhs_id).index;
339-
}
340-
341-
// Sort and dedupe the rewrite constraints, with accesses to the same associated
342-
// constants through the same facet value being treated as equivalent.
343-
static auto SortAndDedupeRewriteConstraints(
344-
Context& context,
345-
llvm::SmallVector<SemIR::FacetTypeInfo::RewriteConstraint>& rewrites) {
346-
auto ord = [&](const SemIR::FacetTypeInfo::RewriteConstraint& a,
347-
const SemIR::FacetTypeInfo::RewriteConstraint& b) {
348-
auto lhs = CompareFacetTypeConstraintValues(context, a.lhs_id, b.lhs_id);
349-
if (lhs != std::weak_ordering::equivalent) {
350-
return lhs;
351-
}
352-
auto rhs = CompareFacetTypeConstraintValues(context, a.rhs_id, b.rhs_id);
353-
return rhs;
354-
};
355-
356-
auto less = [&](const SemIR::FacetTypeInfo::RewriteConstraint& a,
357-
const SemIR::FacetTypeInfo::RewriteConstraint& b) {
358-
return ord(a, b) == std::weak_ordering::less;
359-
};
360-
llvm::stable_sort(rewrites, less);
361-
362-
auto eq = [&](const SemIR::FacetTypeInfo::RewriteConstraint& a,
363-
const SemIR::FacetTypeInfo::RewriteConstraint& b) {
364-
return ord(a, b) == std::weak_ordering::equivalent;
365-
};
366-
rewrites.erase(llvm::unique(rewrites, eq), rewrites.end());
311+
return context.constant_values().GetConstantInstId(lhs_id) ==
312+
context.constant_values().GetConstantInstId(rhs_id);
367313
}
368314

369315
// A mapping of each associated constant (represented as `ImplWitnessAccess`) to
@@ -403,19 +349,12 @@ class AccessRewriteValues {
403349
}
404350
}
405351

406-
auto SetFullyRewritten(Value& value, SemIR::InstId rewritten_to_inst_id)
352+
auto SetFullyRewritten(Context& context, Value& value, SemIR::InstId inst_id)
407353
-> void {
408-
// TODO: If state == FullyRewrtten and the inst id is different (according
409-
// to `CompareFacetTypeConstraintValues`), we can diagnose writing two
410-
// different values for the same associated constant immediately?
411-
//
412-
// TODO: Once the above is done, we don't need to do the SortAndDedupe
413-
// step in ResolveFacetTypeRewriteConstraints()? We can just convert this
414-
// `map_` into a new set of `RewriteConstraint`s which will already be
415-
// deduped.
416-
if (value.state == BeingRewritten) {
417-
value = {FullyRewritten, rewritten_to_inst_id};
418-
}
354+
CARBON_CHECK(
355+
value.state == BeingRewritten ||
356+
CompareFacetTypeConstraintValues(context, value.inst_id, inst_id));
357+
value = {FullyRewritten, inst_id};
419358
}
420359

421360
private:
@@ -573,7 +512,7 @@ class SubstImplWitnessAccessCallbacks : public SubstInstCallbacks {
573512
subst_inst_id)) {
574513
auto* rewrite_value = rewrite_values_->FindRef(context(), *access);
575514
CARBON_CHECK(rewrite_value);
576-
rewrite_values_->SetFullyRewritten(*rewrite_value, inst_id);
515+
rewrite_values_->SetFullyRewritten(context(), *rewrite_value, inst_id);
577516
}
578517
return inst_id;
579518
}
@@ -584,7 +523,8 @@ class SubstImplWitnessAccessCallbacks : public SubstInstCallbacks {
584523
subst_inst_id)) {
585524
auto* rewrite_value = rewrite_values_->FindRef(context(), *access);
586525
CARBON_CHECK(rewrite_value);
587-
rewrite_values_->SetFullyRewritten(*rewrite_value, orig_inst_id);
526+
rewrite_values_->SetFullyRewritten(context(), *rewrite_value,
527+
orig_inst_id);
588528
}
589529
return orig_inst_id;
590530
}
@@ -624,14 +564,9 @@ auto ResolveFacetTypeRewriteConstraints(
624564
return true;
625565
}
626566

627-
// Apply rewrite constraints to each other, so that for example:
628-
// `.X = Y and .Y = ()` becomes `.X = () and .Y = ()`.
629567
AccessRewriteValues rewrite_values;
568+
630569
for (auto& constraint : rewrites) {
631-
if (constraint.lhs_id == SemIR::ErrorInst::InstId ||
632-
constraint.rhs_id == SemIR::ErrorInst::InstId) {
633-
continue;
634-
}
635570
auto lhs_access =
636571
context.insts().TryGetAs<SemIR::ImplWitnessAccess>(constraint.lhs_id);
637572
if (!lhs_access) {
@@ -640,104 +575,87 @@ auto ResolveFacetTypeRewriteConstraints(
640575

641576
rewrite_values.InsertNotRewritten(context, *lhs_access, constraint.rhs_id);
642577
}
578+
643579
for (auto& constraint : rewrites) {
644-
if (constraint.lhs_id == SemIR::ErrorInst::InstId ||
645-
constraint.rhs_id == SemIR::ErrorInst::InstId) {
646-
continue;
647-
}
648580
auto lhs_access =
649581
context.insts().TryGetAs<SemIR::ImplWitnessAccess>(constraint.lhs_id);
650582
if (!lhs_access) {
651583
continue;
652584
}
653585

654586
auto* lhs_rewrite_value = rewrite_values.FindRef(context, *lhs_access);
587+
// Every LHS was added with InsertNotRewritten above.
655588
CARBON_CHECK(lhs_rewrite_value);
656589
rewrite_values.SetBeingRewritten(*lhs_rewrite_value);
657590

658591
auto replace_witness_callbacks =
659592
SubstImplWitnessAccessCallbacks(&context, loc_id, &rewrite_values);
660-
auto subst_inst_id =
593+
auto rhs_subst_inst_id =
661594
SubstInst(context, constraint.rhs_id, replace_witness_callbacks);
662-
constraint.rhs_id = subst_inst_id;
663-
if (constraint.rhs_id == SemIR::ErrorInst::InstId) {
595+
if (rhs_subst_inst_id == SemIR::ErrorInst::InstId) {
664596
return false;
665597
}
666598

667-
rewrite_values.SetFullyRewritten(*lhs_rewrite_value, subst_inst_id);
599+
if (lhs_rewrite_value->state == AccessRewriteValues::FullyRewritten &&
600+
!CompareFacetTypeConstraintValues(context, lhs_rewrite_value->inst_id,
601+
rhs_subst_inst_id)) {
602+
if (lhs_rewrite_value->inst_id != SemIR::ErrorInst::InstId) {
603+
CARBON_DIAGNOSTIC(AssociatedConstantWithDifferentValues, Error,
604+
"associated constant {0} given two different "
605+
"values {1} and {2}",
606+
InstIdAsConstant, InstIdAsConstant, InstIdAsConstant);
607+
// Use inst id ordering as a simple proxy for source ordering, to
608+
// try to name the values in the same order they appear in the facet
609+
// type.
610+
auto source_order1 =
611+
lhs_rewrite_value->inst_id.index < rhs_subst_inst_id.index
612+
? lhs_rewrite_value->inst_id
613+
: rhs_subst_inst_id;
614+
auto source_order2 =
615+
lhs_rewrite_value->inst_id.index >= rhs_subst_inst_id.index
616+
? lhs_rewrite_value->inst_id
617+
: rhs_subst_inst_id;
618+
// TODO: It would be nice to note the places where the values are
619+
// assigned but rewrite constraint instructions are from canonical
620+
// constant values, and have no locations. We'd need to store a
621+
// location along with them in the rewrite constraints.
622+
context.emitter().Emit(loc_id, AssociatedConstantWithDifferentValues,
623+
constraint.lhs_id, source_order1, source_order2);
624+
}
625+
return false;
626+
}
627+
628+
rewrite_values.SetFullyRewritten(context, *lhs_rewrite_value,
629+
rhs_subst_inst_id);
668630
}
669631

670-
// We sort the constraints so that we can find different values being written
671-
// to the same LHS by looking at consecutive rewrite constraints.
672-
//
673-
// It is important to dedupe so that we don't have redundant rewrite
674-
// constraints, as these lead to being diagnosed as a cycle. For example:
675-
// ```
676-
// (T:! Z where .X = .Y) where .X = .Y
677-
// ```
678-
// Here we drop one of the `.X = .Y` in the resulting facet type. If we don't,
679-
// then the `.X` in the outer facet type can be evaluated to `.Y` from the
680-
// inner facet type, resulting in `.Y = .Y` which is a cycle. By deduping, we
681-
// avoid any LHS of a rewrite constraint from being evaluated to the RHS of
682-
// a duplicate rewrite constraint.
683-
SortAndDedupeRewriteConstraints(context, rewrites);
684-
685-
for (size_t i = 0; i < rewrites.size() - 1; ++i) {
632+
// Rebuild the `rewrites` vector with resolved values for the RHS. Drop any
633+
// duplicate rewrites in the `rewrites` vector by walking through the
634+
// `rewrite_values` map and dropping the computed RHS value for each LHS the
635+
// first time we see it, and erasing the constraint from the vector if we see
636+
// the same LHS again.
637+
size_t keep_size = rewrites.size();
638+
for (size_t i = 0; i < keep_size;) {
686639
auto& constraint = rewrites[i];
687-
if (constraint.lhs_id == SemIR::ErrorInst::InstId ||
688-
constraint.rhs_id == SemIR::ErrorInst::InstId) {
689-
continue;
690-
}
691640

692641
auto lhs_access =
693642
context.insts().TryGetAs<SemIR::ImplWitnessAccess>(constraint.lhs_id);
694643
if (!lhs_access) {
644+
++i;
695645
continue;
696646
}
697647

698-
// This loop moves `i` to the last position with the same LHS value, so that
699-
// we don't diagnose more than once within the same contiguous range of
700-
// assignments to a single LHS value.
701-
for (; i < rewrites.size() - 1; ++i) {
702-
auto& next = rewrites[i + 1];
703-
auto next_lhs_access =
704-
context.insts().TryGetAs<SemIR::ImplWitnessAccess>(next.lhs_id);
705-
if (!next_lhs_access) {
706-
break;
707-
}
708-
709-
if (CompareFacetTypeConstraintValues(context, constraint.lhs_id,
710-
next.lhs_id) !=
711-
std::weak_ordering::equivalent) {
712-
break;
713-
}
714-
715-
if (constraint.rhs_id != SemIR::ErrorInst::InstId &&
716-
next.rhs_id != SemIR::ErrorInst::InstId) {
717-
CARBON_DIAGNOSTIC(
718-
AssociatedConstantWithDifferentValues, Error,
719-
"associated constant {0} given two different values {1} and {2}",
720-
InstIdAsConstant, InstIdAsConstant, InstIdAsConstant);
721-
// Use inst id ordering as a simple proxy for source ordering, to try
722-
// to name the values in the same order they appear in the facet type.
723-
auto source_order1 = constraint.rhs_id.index < next.rhs_id.index
724-
? constraint.rhs_id
725-
: next.rhs_id;
726-
auto source_order2 = constraint.rhs_id.index >= next.rhs_id.index
727-
? constraint.rhs_id
728-
: next.rhs_id;
729-
// TODO: It would be nice to note the places where the values are
730-
// assigned but rewrite constraint instructions are from canonical
731-
// constant values, and have no locations. We'd need to store a
732-
// location along with them in the rewrite constraints.
733-
context.emitter().Emit(loc_id, AssociatedConstantWithDifferentValues,
734-
constraint.lhs_id, source_order1, source_order2);
735-
}
736-
constraint.rhs_id = SemIR::ErrorInst::InstId;
737-
next.rhs_id = SemIR::ErrorInst::InstId;
738-
return false;
648+
auto& rewrite_value = *rewrite_values.FindRef(context, *lhs_access);
649+
auto rhs_id = std::exchange(rewrite_value.inst_id, SemIR::InstId::None);
650+
if (rhs_id == SemIR::InstId::None) {
651+
std::swap(rewrites[i], rewrites[keep_size - 1]);
652+
--keep_size;
653+
} else {
654+
rewrites[i].rhs_id = rhs_id;
655+
++i;
739656
}
740657
}
658+
rewrites.erase(rewrites.begin() + keep_size, rewrites.end());
741659

742660
return true;
743661
}

0 commit comments

Comments
 (0)