Skip to content

Commit 2e22733

Browse files
authored
Use canonical constant values as the keys for ImplWitnessAccess in AccessRewriteValues (#5912)
Instead of a bespoke structure based on the `EntityName` in the `ImplWitnessAccess`' `.Self` type, use the constant value of the `ImplWitnessAccess` as the map key in `AccessRewriteValues`. This is okay after #5883 makes the `ImplWitnessAccess` to `.Self` canonically the same regardless of how it's constructed with nested `where` expressions. Introduce `KnownInstId` which tracks in the type system that an `InstId` is known to refer to a specific typed inst structure. This avoids writing CHECKs and comments and allows compiler enforcement.
1 parent f6b9826 commit 2e22733

File tree

3 files changed

+106
-104
lines changed

3 files changed

+106
-104
lines changed

toolchain/check/facet_type.cpp

Lines changed: 34 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -263,38 +263,6 @@ auto AllocateFacetTypeImplWitness(Context& context,
263263
context.inst_blocks().ReplacePlaceholder(witness_id, empty_table);
264264
}
265265

266-
namespace {
267-
// TODO: This class should go away, and we should just use the constant value of
268-
// the ImplWitnessAccess as a key in AccessRewriteValues, but that requires
269-
// changing its API to work with InstId instead of ImplWitnessAccess.
270-
struct FacetTypeConstraintValue {
271-
SemIR::EntityNameId entity_name_id;
272-
SemIR::ElementIndex access_index;
273-
SemIR::SpecificInterfaceId specific_interface_id;
274-
275-
friend auto operator==(const FacetTypeConstraintValue& lhs,
276-
const FacetTypeConstraintValue& rhs) -> bool = default;
277-
};
278-
} // namespace
279-
280-
static auto GetFacetTypeConstraintValue(Context& context,
281-
SemIR::ImplWitnessAccess access)
282-
-> std::optional<FacetTypeConstraintValue> {
283-
auto lookup =
284-
context.insts().TryGetAs<SemIR::LookupImplWitness>(access.witness_id);
285-
if (lookup) {
286-
auto self = context.insts().TryGetAs<SemIR::BindSymbolicName>(
287-
context.constant_values().GetConstantInstId(
288-
lookup->query_self_inst_id));
289-
if (self) {
290-
return {{.entity_name_id = self->entity_name_id,
291-
.access_index = access.index,
292-
.specific_interface_id = lookup->query_specific_interface_id}};
293-
}
294-
}
295-
return std::nullopt;
296-
}
297-
298266
// A mapping of each associated constant (represented as `ImplWitnessAccess`) to
299267
// its value (represented as an `InstId`). Used to track rewrite constraints,
300268
// with the LHS mapping to the resolved value of the RHS.
@@ -310,24 +278,24 @@ class AccessRewriteValues {
310278
SemIR::InstId inst_id;
311279
};
312280

313-
auto InsertNotRewritten(Context& context, SemIR::ImplWitnessAccess access,
314-
SemIR::InstId inst_id) -> void {
315-
map_.insert({*GetKey(context, access), {NotRewritten, inst_id}});
281+
auto InsertNotRewritten(
282+
Context& context, SemIR::KnownInstId<SemIR::ImplWitnessAccess> access_id,
283+
SemIR::InstId inst_id) -> void {
284+
map_.Insert(context.constant_values().Get(access_id),
285+
{NotRewritten, inst_id});
316286
}
317287

318288
// Finds and returns a pointer into the cache for a given ImplWitnessAccess.
319289
// The pointer will be invalidated by mutating the cache. Returns `nullptr`
320290
// if `access` is not found.
321-
auto FindRef(Context& context, SemIR::ImplWitnessAccess access) -> Value* {
322-
auto key = GetKey(context, access);
323-
if (!key) {
324-
return nullptr;
325-
}
326-
auto it = map_.find(*key);
327-
if (it == map_.end()) {
291+
auto FindRef(Context& context,
292+
SemIR::KnownInstId<SemIR::ImplWitnessAccess> access_id)
293+
-> Value* {
294+
auto result = map_.Lookup(context.constant_values().Get(access_id));
295+
if (!result) {
328296
return nullptr;
329297
}
330-
return &it->second;
298+
return &result.value();
331299
}
332300

333301
auto SetBeingRewritten(Value& value) -> void {
@@ -346,54 +314,13 @@ class AccessRewriteValues {
346314
}
347315

348316
private:
349-
using Key = FacetTypeConstraintValue;
350-
struct KeyInfo {
351-
static auto getEmptyKey() -> Key {
352-
return {
353-
.entity_name_id = SemIR::EntityNameId::None,
354-
.access_index = SemIR::ElementIndex(-1),
355-
.specific_interface_id = SemIR::SpecificInterfaceId::None,
356-
};
357-
}
358-
static auto getTombstoneKey() -> Key {
359-
return {
360-
.entity_name_id = SemIR::EntityNameId::None,
361-
.access_index = SemIR::ElementIndex(-2),
362-
.specific_interface_id = SemIR::SpecificInterfaceId::None,
363-
};
364-
}
365-
static auto getHashValue(Key key) -> unsigned {
366-
// This hash produces the same value if two ImplWitnessAccess are
367-
// pointing to the same associated constant, even if they are different
368-
// instruction ids.
369-
//
370-
// TODO: This truncates the high bits of the hash code which does not
371-
// make for a good hash function.
372-
return static_cast<unsigned>(static_cast<uint64_t>(HashValue(key)));
373-
}
374-
static auto isEqual(Key lhs, Key rhs) -> bool {
375-
// This comparison is true if the two ImplWitnessAccess are pointing to
376-
// the same associated constant, even if they are different instruction
377-
// ids.
378-
return lhs == rhs;
379-
}
380-
};
381-
382-
// Returns a key for the `access` to an associated context if the access is
383-
// through a facet value. If the access it through another `ImplWitnessAccess`
384-
// then no key is able to be made.
385-
auto GetKey(Context& context, SemIR::ImplWitnessAccess access)
386-
-> std::optional<Key> {
387-
return GetFacetTypeConstraintValue(context, access);
388-
}
389-
390317
// Try avoid heap allocations in the common case where there are a small
391318
// number of rewrite rules referring to each other by keeping up to 16 on
392319
// the stack.
393320
//
394321
// TODO: Revisit if 16 is an appropriate number when we can measure how deep
395322
// rewrite constraint chains go in practice.
396-
llvm::SmallDenseMap<Key, Value, 16, KeyInfo> map_;
323+
Map<SemIR::ConstantId, Value, 16> map_;
397324
};
398325

399326
// To be used for substituting into the RHS of a rewrite constraint.
@@ -437,7 +364,7 @@ class SubstImplWitnessAccessCallbacks : public SubstInstCallbacks {
437364

438365
auto Subst(SemIR::InstId& rhs_inst_id) -> SubstResult override {
439366
auto rhs_access =
440-
context().insts().TryGetAs<SemIR::ImplWitnessAccess>(rhs_inst_id);
367+
context().insts().TryGetAsWithId<SemIR::ImplWitnessAccess>(rhs_inst_id);
441368
if (!rhs_access) {
442369
// We only want to substitute ImplWitnessAccesses written directly on the
443370
// RHS of the rewrite constraint, not when they are nested inside facet
@@ -471,15 +398,16 @@ class SubstImplWitnessAccessCallbacks : public SubstInstCallbacks {
471398
// access needs to be resolved to a facet value first. If it can't be
472399
// resolved then the outer one can not be either.
473400
if (auto lookup = context().insts().TryGetAs<SemIR::LookupImplWitness>(
474-
rhs_access->witness_id)) {
401+
rhs_access->inst.witness_id)) {
475402
if (context().insts().Is<SemIR::ImplWitnessAccess>(
476403
lookup->query_self_inst_id)) {
477404
substs_in_progress_.push_back(rhs_inst_id);
478405
return SubstResult::SubstOperandsAndRetry;
479406
}
480407
}
481408

482-
auto* rewrite_value = rewrite_values_->FindRef(context(), *rhs_access);
409+
auto* rewrite_value =
410+
rewrite_values_->FindRef(context(), rhs_access->inst_id);
483411
if (!rewrite_value) {
484412
// The RHS refers to an associated constant for which there is no rewrite
485413
// rule.
@@ -521,9 +449,11 @@ class SubstImplWitnessAccessCallbacks : public SubstInstCallbacks {
521449
-> SemIR::InstId override {
522450
auto inst_id = RebuildNewInst(loc_id_, new_inst);
523451
auto subst_inst_id = substs_in_progress_.pop_back_val();
524-
if (auto access = context().insts().TryGetAs<SemIR::ImplWitnessAccess>(
525-
subst_inst_id)) {
526-
if (auto* rewrite_value = rewrite_values_->FindRef(context(), *access)) {
452+
if (auto access =
453+
context().insts().TryGetAsWithId<SemIR::ImplWitnessAccess>(
454+
subst_inst_id)) {
455+
if (auto* rewrite_value =
456+
rewrite_values_->FindRef(context(), access->inst_id)) {
527457
rewrite_values_->SetFullyRewritten(context(), *rewrite_value, inst_id);
528458
}
529459
}
@@ -532,9 +462,11 @@ class SubstImplWitnessAccessCallbacks : public SubstInstCallbacks {
532462

533463
auto ReuseUnchanged(SemIR::InstId orig_inst_id) -> SemIR::InstId override {
534464
auto subst_inst_id = substs_in_progress_.pop_back_val();
535-
if (auto access = context().insts().TryGetAs<SemIR::ImplWitnessAccess>(
536-
subst_inst_id)) {
537-
if (auto* rewrite_value = rewrite_values_->FindRef(context(), *access)) {
465+
if (auto access =
466+
context().insts().TryGetAsWithId<SemIR::ImplWitnessAccess>(
467+
subst_inst_id)) {
468+
if (auto* rewrite_value =
469+
rewrite_values_->FindRef(context(), access->inst_id)) {
538470
rewrite_values_->SetFullyRewritten(context(), *rewrite_value,
539471
orig_inst_id);
540472
}
@@ -580,23 +512,25 @@ auto ResolveFacetTypeRewriteConstraints(
580512
AccessRewriteValues rewrite_values;
581513

582514
for (auto& constraint : rewrites) {
583-
auto lhs_access = context.insts().TryGetAs<SemIR::ImplWitnessAccess>(
515+
auto lhs_access = context.insts().TryGetAsWithId<SemIR::ImplWitnessAccess>(
584516
GetImplWitnessAccessWithoutSubstitution(context, constraint.lhs_id));
585517
if (!lhs_access) {
586518
continue;
587519
}
588520

589-
rewrite_values.InsertNotRewritten(context, *lhs_access, constraint.rhs_id);
521+
rewrite_values.InsertNotRewritten(context, lhs_access->inst_id,
522+
constraint.rhs_id);
590523
}
591524

592525
for (auto& constraint : rewrites) {
593-
auto lhs_access = context.insts().TryGetAs<SemIR::ImplWitnessAccess>(
526+
auto lhs_access = context.insts().TryGetAsWithId<SemIR::ImplWitnessAccess>(
594527
GetImplWitnessAccessWithoutSubstitution(context, constraint.lhs_id));
595528
if (!lhs_access) {
596529
continue;
597530
}
598531

599-
auto* lhs_rewrite_value = rewrite_values.FindRef(context, *lhs_access);
532+
auto* lhs_rewrite_value =
533+
rewrite_values.FindRef(context, lhs_access->inst_id);
600534
// Every LHS was added with InsertNotRewritten above.
601535
CARBON_CHECK(lhs_rewrite_value);
602536
rewrite_values.SetBeingRewritten(*lhs_rewrite_value);
@@ -658,14 +592,14 @@ auto ResolveFacetTypeRewriteConstraints(
658592
for (size_t i = 0; i < keep_size;) {
659593
auto& constraint = rewrites[i];
660594

661-
auto lhs_access = context.insts().TryGetAs<SemIR::ImplWitnessAccess>(
595+
auto lhs_access = context.insts().TryGetAsWithId<SemIR::ImplWitnessAccess>(
662596
GetImplWitnessAccessWithoutSubstitution(context, constraint.lhs_id));
663597
if (!lhs_access) {
664598
++i;
665599
continue;
666600
}
667601

668-
auto& rewrite_value = *rewrite_values.FindRef(context, *lhs_access);
602+
auto& rewrite_value = *rewrite_values.FindRef(context, lhs_access->inst_id);
669603
auto rhs_id = std::exchange(rewrite_value.inst_id, SemIR::InstId::None);
670604
if (rhs_id == SemIR::InstId::None) {
671605
std::swap(rewrites[i], rewrites[keep_size - 1]);

toolchain/sem_ir/ids.h

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,10 @@ struct InstId : public IdBase<InstId> {
3737

3838
constexpr InstId InstId::InitTombstone = InstId(NoneIndex - 1);
3939

40-
// An InstId whose value is a type. The fact it's a type is CHECKed on
41-
// construction, and this allows that check to be represented in the type
42-
// system.
40+
// An InstId whose value is a type. The fact it's a type must be validated
41+
// before construction, and this allows that validation to be represented in the
42+
// type system.
4343
struct TypeInstId : public InstId {
44-
static constexpr llvm::StringLiteral Label = "type_inst";
4544
static const TypeInstId None;
4645

4746
using InstId::InstId;
@@ -58,6 +57,33 @@ struct TypeInstId : public InstId {
5857

5958
constexpr TypeInstId TypeInstId::None = TypeInstId::UnsafeMake(InstId::None);
6059

60+
// An InstId whose type is known to be T. The fact it's a type must be validated
61+
// before construction, and this allows that validation to be represented in the
62+
// type system.
63+
//
64+
// Unlike TypeInstId, this type can *not* be an operand in instructions, since
65+
// being a template prevents it from being used in non-generic contexts such as
66+
// switches.
67+
template <class T>
68+
struct KnownInstId : public InstId {
69+
static const KnownInstId None;
70+
71+
using InstId::InstId;
72+
73+
static constexpr auto UnsafeMake(InstId id) -> KnownInstId {
74+
return KnownInstId(UnsafeCtor(), id);
75+
}
76+
77+
private:
78+
struct UnsafeCtor {};
79+
explicit constexpr KnownInstId(UnsafeCtor /*unsafe*/, InstId id)
80+
: InstId(id) {}
81+
};
82+
83+
template <class T>
84+
constexpr KnownInstId<T> KnownInstId<T>::None =
85+
KnownInstId<T>::UnsafeMake(InstId::None);
86+
6187
// An ID of an instruction that is referenced absolutely by another instruction.
6288
// This should only be used as the type of a field within a typed instruction
6389
// class.

toolchain/sem_ir/inst.h

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "toolchain/base/int.h"
1919
#include "toolchain/base/value_store.h"
2020
#include "toolchain/sem_ir/id_kind.h"
21+
#include "toolchain/sem_ir/ids.h"
2122
#include "toolchain/sem_ir/inst_kind.h"
2223
#include "toolchain/sem_ir/singleton_insts.h"
2324
#include "toolchain/sem_ir/typed_insts.h"
@@ -465,6 +466,13 @@ class InstStore {
465466
return result;
466467
}
467468

469+
// Returns the requested instruction, which is known to have the specified
470+
// type.
471+
template <typename InstT>
472+
auto Get(KnownInstId<InstT> inst_id) const -> InstT {
473+
return Get(static_cast<InstId>(inst_id)).As<InstT>();
474+
}
475+
468476
// Returns the requested instruction, preserving its attached type.
469477
auto GetWithAttachedType(InstId inst_id) const -> Inst {
470478
return values_.Get(inst_id);
@@ -500,6 +508,10 @@ class InstStore {
500508
return Get(inst_id).TryAs<InstT>();
501509
}
502510

511+
// Use `Get()` when the instruction type is known.
512+
template <typename InstT, typename KnownInstT>
513+
auto TryGetAs(KnownInstId<KnownInstT> inst_id) const = delete;
514+
503515
// Returns the requested instruction as the specified type, if it is valid and
504516
// of that type. Otherwise returns nullopt.
505517
template <typename InstT>
@@ -510,6 +522,36 @@ class InstStore {
510522
return TryGetAs<InstT>(inst_id);
511523
}
512524

525+
template <class InstT>
526+
struct GetAsWithIdResult {
527+
SemIR::KnownInstId<InstT> inst_id;
528+
InstT inst;
529+
};
530+
531+
// Returns the requested instruction, which is known to have the specified
532+
// type, along with the original `InstId`, encoding the work of checking its
533+
// type in a `KnownInstId`.
534+
template <typename InstT>
535+
auto GetAsWithId(InstId inst_id) const -> GetAsWithIdResult<InstT> {
536+
auto inst = GetAs<InstT>(inst_id);
537+
return {.inst_id = SemIR::KnownInstId<InstT>::UnsafeMake(inst_id),
538+
.inst = inst};
539+
}
540+
541+
// Returns the requested instruction, if it is of that type, along with the
542+
// original `InstId`, encoding the work of checking its type in a
543+
// `KnownInstId`.
544+
template <typename InstT>
545+
auto TryGetAsWithId(InstId inst_id) const
546+
-> std::optional<GetAsWithIdResult<InstT>> {
547+
auto inst = TryGetAs<InstT>(inst_id);
548+
if (!inst) {
549+
return std::nullopt;
550+
}
551+
return {{.inst_id = SemIR::KnownInstId<InstT>::UnsafeMake(inst_id),
552+
.inst = *inst}};
553+
}
554+
513555
// Attempts to convert the given instruction to the type that contains
514556
// `member`. If it can be converted, the instruction ID and instruction are
515557
// replaced by the unwrapped value of that member, and the converted wrapper

0 commit comments

Comments
 (0)