Skip to content

Commit 8cb932b

Browse files
zygoloidjonmeowCarbonInfraBot
authored
Substitute Self in associated function signatures before checking them against impls. (#3788)
Add a general substitution mechanism to support substituting symbolic bindings with their values throughout symbolic constants and, more specifically, types. This is done by decomposing the constant instruction into its operands, substituting into the operands, and then rebuilding the constant value by invoking the constant evaluator. --------- Co-authored-by: Jon Ross-Perkins <[email protected]> Co-authored-by: Carbon Infra Bot <[email protected]>
1 parent 2584399 commit 8cb932b

19 files changed

+755
-189
lines changed

toolchain/check/BUILD

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ cc_library(
142142
hdrs = ["function.h"],
143143
deps = [
144144
":context",
145+
":subst",
145146
"//common:check",
146147
"//toolchain/sem_ir:file",
147148
],
@@ -154,6 +155,7 @@ cc_library(
154155
deps = [
155156
":context",
156157
":function",
158+
":subst",
157159
"//common:check",
158160
"//toolchain/diagnostics:diagnostic_emitter",
159161
"//toolchain/sem_ir:file",
@@ -196,6 +198,23 @@ cc_library(
196198
name = "member_access",
197199
srcs = ["member_access.cpp"],
198200
hdrs = ["member_access.h"],
201+
deps = [
202+
":context",
203+
":subst",
204+
"//common:check",
205+
"//toolchain/diagnostics:diagnostic_emitter",
206+
"//toolchain/sem_ir:file",
207+
"//toolchain/sem_ir:ids",
208+
"//toolchain/sem_ir:inst",
209+
"//toolchain/sem_ir:inst_kind",
210+
"@llvm-project//llvm:Support",
211+
],
212+
)
213+
214+
cc_library(
215+
name = "subst",
216+
srcs = ["subst.cpp"],
217+
hdrs = ["subst.h"],
199218
deps = [
200219
":context",
201220
"//common:check",

toolchain/check/function.cpp

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
#include "toolchain/check/function.h"
66

7+
#include "toolchain/check/subst.h"
8+
79
namespace Carbon::Check {
810

911
// Returns true if there was an error in declaring the function, which will have
@@ -31,7 +33,8 @@ static auto CheckRedeclParam(Context& context,
3133
llvm::StringLiteral param_diag_label,
3234
int32_t param_index,
3335
SemIR::InstId new_param_ref_id,
34-
SemIR::InstId prev_param_ref_id) -> bool {
36+
SemIR::InstId prev_param_ref_id,
37+
Substitutions substitutions) -> bool {
3538
// TODO: Consider differentiating between type and name mistakes. For now,
3639
// taking the simpler approach because I also think we may want to refactor
3740
// params.
@@ -52,7 +55,8 @@ static auto CheckRedeclParam(Context& context,
5255
auto new_param_ref = context.insts().Get(new_param_ref_id);
5356
auto prev_param_ref = context.insts().Get(prev_param_ref_id);
5457
if (new_param_ref.kind() != prev_param_ref.kind() ||
55-
new_param_ref.type_id() != prev_param_ref.type_id()) {
58+
new_param_ref.type_id() !=
59+
SubstType(context, prev_param_ref.type_id(), substitutions)) {
5660
diagnose();
5761
return false;
5862
}
@@ -90,7 +94,8 @@ static auto CheckRedeclParams(Context& context, SemIR::InstId new_decl_id,
9094
SemIR::InstBlockId new_param_refs_id,
9195
SemIR::InstId prev_decl_id,
9296
SemIR::InstBlockId prev_param_refs_id,
93-
llvm::StringLiteral param_diag_label) -> bool {
97+
llvm::StringLiteral param_diag_label,
98+
Substitutions substitutions) -> bool {
9499
// This will often occur for empty params.
95100
if (new_param_refs_id == prev_param_refs_id) {
96101
return true;
@@ -116,7 +121,7 @@ static auto CheckRedeclParams(Context& context, SemIR::InstId new_decl_id,
116121
for (auto [index, new_param_ref_id, prev_param_ref_id] :
117122
llvm::enumerate(new_param_ref_ids, prev_param_ref_ids)) {
118123
if (!CheckRedeclParam(context, param_diag_label, index, new_param_ref_id,
119-
prev_param_ref_id)) {
124+
prev_param_ref_id, substitutions)) {
120125
return false;
121126
}
122127
}
@@ -125,21 +130,26 @@ static auto CheckRedeclParams(Context& context, SemIR::InstId new_decl_id,
125130

126131
// Returns false if the provided function declarations differ.
127132
static auto CheckRedecl(Context& context, const SemIR::Function& new_function,
128-
const SemIR::Function& prev_function) -> bool {
133+
const SemIR::Function& prev_function,
134+
Substitutions substitutions) -> bool {
129135
if (FunctionDeclHasError(context, new_function) ||
130136
FunctionDeclHasError(context, prev_function)) {
131137
return false;
132138
}
133-
if (!CheckRedeclParams(context, new_function.decl_id,
134-
new_function.implicit_param_refs_id,
135-
prev_function.decl_id,
136-
prev_function.implicit_param_refs_id, "implicit ") ||
139+
if (!CheckRedeclParams(
140+
context, new_function.decl_id, new_function.implicit_param_refs_id,
141+
prev_function.decl_id, prev_function.implicit_param_refs_id,
142+
"implicit ", substitutions) ||
137143
!CheckRedeclParams(context, new_function.decl_id,
138144
new_function.param_refs_id, prev_function.decl_id,
139-
prev_function.param_refs_id, "")) {
145+
prev_function.param_refs_id, "", substitutions)) {
140146
return false;
141147
}
142-
if (new_function.return_type_id != prev_function.return_type_id) {
148+
auto prev_return_type_id =
149+
prev_function.return_type_id.is_valid()
150+
? SubstType(context, prev_function.return_type_id, substitutions)
151+
: SemIR::TypeId::Invalid;
152+
if (new_function.return_type_id != prev_return_type_id) {
143153
CARBON_DIAGNOSTIC(
144154
FunctionRedeclReturnTypeDiffers, Error,
145155
"Function redeclaration differs because return type is `{0}`.",
@@ -154,12 +164,12 @@ static auto CheckRedecl(Context& context, const SemIR::Function& new_function,
154164
new_function.return_type_id)
155165
: context.emitter().Build(new_function.decl_id,
156166
FunctionRedeclReturnTypeDiffersNoReturn);
157-
if (prev_function.return_type_id.is_valid()) {
167+
if (prev_return_type_id.is_valid()) {
158168
CARBON_DIAGNOSTIC(FunctionRedeclReturnTypePrevious, Note,
159169
"Previously declared with return type `{0}`.",
160170
SemIR::TypeId);
161171
diag.Note(prev_function.decl_id, FunctionRedeclReturnTypePrevious,
162-
prev_function.return_type_id);
172+
prev_return_type_id);
163173
} else {
164174
CARBON_DIAGNOSTIC(FunctionRedeclReturnTypePreviousNoReturn, Note,
165175
"Previously declared with no return type.");
@@ -173,10 +183,12 @@ static auto CheckRedecl(Context& context, const SemIR::Function& new_function,
173183
return true;
174184
}
175185

176-
auto CheckFunctionRedecl(Context& context, SemIR::FunctionId new_function_id,
177-
SemIR::FunctionId prev_function_id) -> bool {
186+
auto CheckFunctionTypeMatches(Context& context,
187+
SemIR::FunctionId new_function_id,
188+
SemIR::FunctionId prev_function_id,
189+
Substitutions substitutions) -> bool {
178190
return CheckRedecl(context, context.functions().Get(new_function_id),
179-
context.functions().Get(prev_function_id));
191+
context.functions().Get(prev_function_id), substitutions);
180192
}
181193

182194
auto MergeFunctionRedecl(Context& context, Parse::NodeId node_id,
@@ -186,7 +198,7 @@ auto MergeFunctionRedecl(Context& context, Parse::NodeId node_id,
186198
auto& prev_function = context.functions().Get(prev_function_id);
187199

188200
// TODO: Disallow redeclarations within classes?
189-
if (!CheckRedecl(context, new_function, prev_function)) {
201+
if (!CheckRedecl(context, new_function, prev_function, {})) {
190202
return false;
191203
}
192204

toolchain/check/function.h

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,20 @@
66
#define CARBON_TOOLCHAIN_CHECK_FUNCTION_H_
77

88
#include "toolchain/check/context.h"
9+
#include "toolchain/check/subst.h"
910
#include "toolchain/sem_ir/function.h"
1011

1112
namespace Carbon::Check {
1213

13-
// Checks that `new_function_id` does not differ from `prev_function_id`.
14-
// Prints a suitable diagnostic and returns false if not.
15-
auto CheckFunctionRedecl(Context& context, SemIR::FunctionId new_function_id,
16-
SemIR::FunctionId prev_function_id) -> bool;
14+
// Checks that `new_function_id` has the same parameter types and return type as
15+
// `prev_function_id`, applying the specified set of substitutions to the
16+
// previous function. Prints a suitable diagnostic and returns false if not.
17+
// Note that this doesn't include the syntactic check that's performed for
18+
// redeclarations.
19+
auto CheckFunctionTypeMatches(Context& context,
20+
SemIR::FunctionId new_function_id,
21+
SemIR::FunctionId prev_function_id,
22+
Substitutions substitutions) -> bool;
1723

1824
// Tries to merge new_function into prev_function_id. Since new_function won't
1925
// have a definition even if one is upcoming, set is_definition to indicate the

toolchain/check/impl.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
#include "toolchain/check/context.h"
88
#include "toolchain/check/function.h"
9+
#include "toolchain/check/subst.h"
910
#include "toolchain/diagnostics/diagnostic_emitter.h"
1011
#include "toolchain/sem_ir/ids.h"
1112
#include "toolchain/sem_ir/impl.h"
@@ -29,7 +30,7 @@ static auto NoteAssociatedFunction(Context& context,
2930
// `BuiltinError` if the function is not usable.
3031
static auto CheckAssociatedFunctionImplementation(
3132
Context& context, SemIR::FunctionId interface_function_id,
32-
SemIR::InstId impl_decl_id) -> SemIR::InstId {
33+
SemIR::InstId impl_decl_id, Substitutions substitutions) -> SemIR::InstId {
3334
auto impl_function_decl =
3435
context.insts().TryGetAs<SemIR::FunctionDecl>(impl_decl_id);
3536
if (!impl_function_decl) {
@@ -49,8 +50,8 @@ static auto CheckAssociatedFunctionImplementation(
4950
// before checking. Also, this should be a semantic check rather than a
5051
// syntactic one. The functions should be allowed to have different signatures
5152
// as long as we can synthesize a suitable thunk.
52-
if (!CheckFunctionRedecl(context, impl_function_decl->function_id,
53-
interface_function_id)) {
53+
if (!CheckFunctionTypeMatches(context, impl_function_decl->function_id,
54+
interface_function_id, substitutions)) {
5455
return SemIR::InstId::BuiltinError;
5556
}
5657
return impl_decl_id;
@@ -79,6 +80,11 @@ static auto BuildInterfaceWitness(
7980
context.inst_blocks().Get(interface.associated_entities_id);
8081
table.reserve(assoc_entities.size());
8182

83+
// Substitute `Self` with the impl's self type when associated functions.
84+
Substitution substitutions[1] = {
85+
{.bind_id = interface.self_param_id,
86+
.replacement_id = context.types().GetConstantId(impl.self_id)}};
87+
8288
for (auto decl_id : assoc_entities) {
8389
auto decl = context.insts().Get(decl_id);
8490
if (auto fn_decl = decl.TryAs<SemIR::FunctionDecl>()) {
@@ -88,7 +94,7 @@ static auto BuildInterfaceWitness(
8894
if (impl_decl_id.is_valid()) {
8995
used_decl_ids.push_back(impl_decl_id);
9096
table.push_back(CheckAssociatedFunctionImplementation(
91-
context, fn_decl->function_id, impl_decl_id));
97+
context, fn_decl->function_id, impl_decl_id, substitutions));
9298
} else {
9399
CARBON_DIAGNOSTIC(
94100
ImplMissingFunction, Error,

toolchain/check/member_access.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "llvm/ADT/STLExtras.h"
66
#include "toolchain/check/context.h"
77
#include "toolchain/check/convert.h"
8+
#include "toolchain/check/subst.h"
89
#include "toolchain/diagnostics/diagnostic_emitter.h"
910
#include "toolchain/sem_ir/inst.h"
1011
#include "toolchain/sem_ir/typed_insts.h"
@@ -151,6 +152,7 @@ static auto LookupInterfaceWitness(Context& context,
151152
static auto PerformImplLookup(Context& context, SemIR::ConstantId type_const_id,
152153
SemIR::AssociatedEntityType assoc_type,
153154
SemIR::InstId member_id) -> SemIR::InstId {
155+
auto& interface = context.interfaces().Get(assoc_type.interface_id);
154156
auto witness_id =
155157
LookupInterfaceWitness(context, type_const_id, assoc_type.interface_id);
156158
if (!witness_id.is_valid()) {
@@ -159,8 +161,7 @@ static auto PerformImplLookup(Context& context, SemIR::ConstantId type_const_id,
159161
"that does not implement that interface.",
160162
SemIR::NameId, std::string);
161163
context.emitter().Emit(
162-
member_id, MissingImplInMemberAccess,
163-
context.interfaces().Get(assoc_type.interface_id).name_id,
164+
member_id, MissingImplInMemberAccess, interface.name_id,
164165
context.sem_ir().StringifyTypeExpr(type_const_id.inst_id()));
165166
return SemIR::InstId::BuiltinError;
166167
}
@@ -180,9 +181,14 @@ static auto PerformImplLookup(Context& context, SemIR::ConstantId type_const_id,
180181
return SemIR::InstId::BuiltinError;
181182
}
182183

183-
// TODO: Substitute interface arguments and `Self` into `entity_type_id`.
184+
// Substitute into the type declared in the interface.
185+
Substitution substitutions[1] = {
186+
{.bind_id = interface.self_param_id, .replacement_id = type_const_id}};
187+
auto subst_type_id =
188+
SubstType(context, assoc_type.entity_type_id, substitutions);
189+
184190
return context.AddInst(SemIR::InterfaceWitnessAccess{
185-
assoc_type.entity_type_id, witness_id, assoc_entity->index});
191+
subst_type_id, witness_id, assoc_entity->index});
186192
}
187193

188194
// Performs a member name lookup into the specified scope, including performing

0 commit comments

Comments
 (0)