Skip to content

Commit 9d942f4

Browse files
geoffromerjonmeow
andauthored
Generate parameter pattern-match IR from pattern IR (#4388)
Also propagate the pattern IR along with the pattern-match IR, and use it where appropriate. Strictly speaking, some parts of the pattern-match IR are allocated eagerly, while traversing the pattern's parse tree, but they still aren't actually emitted until we traverse the associated pattern insts. --------- Co-authored-by: Jon Ross-Perkins <[email protected]>
1 parent 4a73b36 commit 9d942f4

File tree

347 files changed

+9145
-5799
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

347 files changed

+9145
-5799
lines changed

toolchain/check/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ cc_library(
3232
"modifiers.cpp",
3333
"name_component.cpp",
3434
"operator.cpp",
35+
"pattern_match.cpp",
3536
"return.cpp",
3637
"subst.cpp",
3738
],
@@ -57,6 +58,7 @@ cc_library(
5758
"name_component.h",
5859
"operator.h",
5960
"param_and_arg_refs_stack.h",
61+
"pattern_match.h",
6062
"pending_block.h",
6163
"return.h",
6264
"subst.h",

toolchain/check/call.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@ static auto ResolveCalleeInCall(Context& context, SemIR::LocId loc_id,
6161
if (entity_generic_id.is_valid()) {
6262
specific_id = DeduceGenericCallArguments(
6363
context, loc_id, entity_generic_id, enclosing_specific_id,
64-
callee_info.implicit_param_refs_id, callee_info.param_refs_id, self_id,
65-
arg_ids);
64+
callee_info.implicit_param_patterns_id, callee_info.param_patterns_id,
65+
self_id, arg_ids);
6666
if (!specific_id.is_valid()) {
6767
return std::nullopt;
6868
}

toolchain/check/context.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -921,6 +921,7 @@ class TypeCompleter {
921921
-> SemIR::ValueRepr {
922922
switch (builtin.builtin_inst_kind) {
923923
case SemIR::BuiltinInstKind::TypeType:
924+
case SemIR::BuiltinInstKind::AutoType:
924925
case SemIR::BuiltinInstKind::Error:
925926
case SemIR::BuiltinInstKind::Invalid:
926927
case SemIR::BuiltinInstKind::BoolType:

toolchain/check/convert.cpp

Lines changed: 30 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "toolchain/base/kind_switch.h"
1414
#include "toolchain/check/context.h"
1515
#include "toolchain/check/operator.h"
16+
#include "toolchain/check/pattern_match.h"
1617
#include "toolchain/sem_ir/copy_on_write_block.h"
1718
#include "toolchain/sem_ir/file.h"
1819
#include "toolchain/sem_ir/generic.h"
@@ -1138,9 +1139,8 @@ CARBON_DIAGNOSTIC(InCallToFunction, Note, "calling function declared here");
11381139
static auto ConvertSelf(Context& context, SemIR::LocId call_loc_id,
11391140
SemIRLoc callee_loc,
11401141
SemIR::SpecificId callee_specific_id,
1141-
std::optional<SemIR::AddrPattern> addr_pattern,
1142-
SemIR::InstId self_param_id, SemIR::Param self_param,
1143-
SemIR::InstId self_id) -> SemIR::InstId {
1142+
SemIR::InstId self_param_id, SemIR::InstId self_id)
1143+
-> SemIR::InstId {
11441144
if (!self_id.is_valid()) {
11451145
CARBON_DIAGNOSTIC(MissingObjectInMethodCall, Error,
11461146
"missing object argument in method call");
@@ -1151,6 +1151,7 @@ static auto ConvertSelf(Context& context, SemIR::LocId call_loc_id,
11511151
return SemIR::InstId::BuiltinError;
11521152
}
11531153

1154+
bool addr_pattern = context.insts().Is<SemIR::AddrPattern>(self_param_id);
11541155
DiagnosticAnnotationScope annotate_diagnostics(
11551156
&context.emitter(), [&](auto& builder) {
11561157
CARBON_DIAGNOSTIC(
@@ -1162,69 +1163,46 @@ static auto ConvertSelf(Context& context, SemIR::LocId call_loc_id,
11621163
: llvm::StringLiteral("self"));
11631164
});
11641165

1165-
// For `addr self`, take the address of the object argument.
1166-
auto self_or_addr_id = self_id;
1167-
if (addr_pattern) {
1168-
self_or_addr_id = ConvertToValueOrRefExpr(context, self_or_addr_id);
1169-
auto self = context.insts().Get(self_or_addr_id);
1170-
switch (SemIR::GetExprCategory(context.sem_ir(), self_id)) {
1171-
case SemIR::ExprCategory::Error:
1172-
case SemIR::ExprCategory::DurableRef:
1173-
case SemIR::ExprCategory::EphemeralRef:
1174-
break;
1175-
default:
1176-
CARBON_DIAGNOSTIC(AddrSelfIsNonRef, Error,
1177-
"`addr self` method cannot be invoked on a value");
1178-
context.emitter().Emit(TokenOnly(call_loc_id), AddrSelfIsNonRef);
1179-
return SemIR::InstId::BuiltinError;
1180-
}
1181-
auto loc_id = context.insts().GetLocId(self_or_addr_id);
1182-
self_or_addr_id = context.AddInst<SemIR::AddrOf>(
1183-
loc_id, {.type_id = context.GetPointerType(self.type_id()),
1184-
.lvalue_id = self_or_addr_id});
1185-
}
1186-
1187-
return ConvertToValueOfType(
1188-
context, call_loc_id, self_or_addr_id,
1189-
SemIR::GetTypeInSpecific(context.sem_ir(), callee_specific_id,
1190-
self_param.type_id));
1166+
return CallerPatternMatch(context, callee_specific_id, self_param_id,
1167+
self_id);
11911168
}
11921169

1170+
// TODO: consider moving this to pattern_match.h
11931171
auto ConvertCallArgs(Context& context, SemIR::LocId call_loc_id,
11941172
SemIR::InstId self_id,
11951173
llvm::ArrayRef<SemIR::InstId> arg_refs,
11961174
SemIR::InstId return_storage_id,
11971175
const CalleeParamsInfo& callee,
11981176
SemIR::SpecificId callee_specific_id)
11991177
-> SemIR::InstBlockId {
1200-
auto implicit_param_refs =
1201-
context.inst_blocks().GetOrEmpty(callee.implicit_param_refs_id);
1202-
auto param_refs = context.inst_blocks().GetOrEmpty(callee.param_refs_id);
1178+
auto implicit_param_patterns =
1179+
context.inst_blocks().GetOrEmpty(callee.implicit_param_patterns_id);
1180+
auto param_patterns =
1181+
context.inst_blocks().GetOrEmpty(callee.param_patterns_id);
12031182

12041183
// The caller should have ensured this callee has the right arity.
1205-
CARBON_CHECK(arg_refs.size() == param_refs.size());
1184+
CARBON_CHECK(arg_refs.size() == param_patterns.size());
12061185

12071186
// Start building a block to hold the converted arguments.
12081187
llvm::SmallVector<SemIR::InstId> args;
1209-
args.reserve(implicit_param_refs.size() + param_refs.size() +
1188+
args.reserve(implicit_param_patterns.size() + param_patterns.size() +
12101189
return_storage_id.is_valid());
12111190

12121191
// Check implicit parameters.
1213-
for (auto implicit_param_id : implicit_param_refs) {
1214-
auto addr_pattern =
1215-
context.insts().TryGetAs<SemIR::AddrPattern>(implicit_param_id);
1216-
auto param_info = SemIR::Function::GetParamFromParamRefId(
1192+
for (auto implicit_param_id : implicit_param_patterns) {
1193+
auto param_pattern_info = SemIR::Function::GetParamPatternInfoFromPatternId(
12171194
context.sem_ir(), implicit_param_id);
1218-
if (param_info.GetNameId(context.sem_ir()) == SemIR::NameId::SelfValue) {
1219-
auto converted_self_id = ConvertSelf(
1220-
context, call_loc_id, callee.callee_loc, callee_specific_id,
1221-
addr_pattern, param_info.inst_id, param_info.inst, self_id);
1195+
if (param_pattern_info.GetNameId(context.sem_ir()) ==
1196+
SemIR::NameId::SelfValue) {
1197+
auto converted_self_id =
1198+
ConvertSelf(context, call_loc_id, callee.callee_loc,
1199+
callee_specific_id, implicit_param_id, self_id);
12221200
if (converted_self_id == SemIR::InstId::BuiltinError) {
12231201
return SemIR::InstBlockId::Invalid;
12241202
}
12251203
args.push_back(converted_self_id);
12261204
} else {
1227-
CARBON_CHECK(!param_info.inst.runtime_index.is_valid(),
1205+
CARBON_CHECK(!param_pattern_info.inst.runtime_index.is_valid(),
12281206
"Unexpected implicit parameter passed at runtime");
12291207
}
12301208
}
@@ -1240,31 +1218,25 @@ auto ConvertCallArgs(Context& context, SemIR::LocId call_loc_id,
12401218
});
12411219

12421220
// Check type conversions per-element.
1243-
for (auto [i, arg_id, param_ref_id] : llvm::enumerate(arg_refs, param_refs)) {
1221+
for (auto [i, arg_id, param_pattern_id] :
1222+
llvm::enumerate(arg_refs, param_patterns)) {
12441223
diag_param_index = i;
12451224

1246-
// TODO: In general we need to perform pattern matching here to find the
1247-
// argument corresponding to each parameter.
1248-
auto param_info =
1249-
SemIR::Function::GetParamFromParamRefId(context.sem_ir(), param_ref_id);
1250-
if (!param_info.inst.runtime_index.is_valid()) {
1225+
auto runtime_index = SemIR::Function::GetParamPatternInfoFromPatternId(
1226+
context.sem_ir(), param_pattern_id)
1227+
.inst.runtime_index;
1228+
if (!runtime_index.is_valid()) {
12511229
// Not a runtime parameter: we don't pass an argument.
12521230
continue;
12531231
}
12541232

1255-
auto param_type_id = SemIR::GetTypeInSpecific(
1256-
context.sem_ir(), callee_specific_id,
1257-
context.insts().Get(param_info.inst_id).type_id());
1258-
// TODO: Convert to the proper expression category. For now, we assume
1259-
// parameters are all `let` bindings.
1260-
auto converted_arg_id =
1261-
ConvertToValueOfType(context, call_loc_id, arg_id, param_type_id);
1233+
auto converted_arg_id = CallerPatternMatch(context, callee_specific_id,
1234+
param_pattern_id, arg_id);
12621235
if (converted_arg_id == SemIR::InstId::BuiltinError) {
12631236
return SemIR::InstBlockId::Invalid;
12641237
}
12651238

1266-
CARBON_CHECK(static_cast<int32_t>(args.size()) ==
1267-
param_info.inst.runtime_index.index,
1239+
CARBON_CHECK(static_cast<int32_t>(args.size()) == runtime_index.index,
12681240
"Parameters not numbered in order.");
12691241
args.push_back(converted_arg_id);
12701242
}

toolchain/check/convert.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,14 +97,18 @@ struct CalleeParamsInfo {
9797
explicit CalleeParamsInfo(const SemIR::EntityWithParamsBase& callee)
9898
: callee_loc(callee.latest_decl_id()),
9999
implicit_param_refs_id(callee.implicit_param_refs_id),
100-
param_refs_id(callee.param_refs_id) {}
100+
implicit_param_patterns_id(callee.implicit_param_patterns_id),
101+
param_refs_id(callee.param_refs_id),
102+
param_patterns_id(callee.param_patterns_id) {}
101103

102104
// The location of the callee to use in diagnostics.
103105
SemIRLoc callee_loc;
104106
// The implicit parameters of the callee.
105107
SemIR::InstBlockId implicit_param_refs_id;
108+
SemIR::InstBlockId implicit_param_patterns_id;
106109
// The explicit parameters of the callee.
107110
SemIR::InstBlockId param_refs_id;
111+
SemIR::InstBlockId param_patterns_id;
108112
};
109113

110114
// Implicitly converts a set of arguments to match the parameter types in a

toolchain/check/decl_name_stack.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -377,9 +377,9 @@ auto DeclNameStack::ResolveAsScope(const NameContext& name_context,
377377
return InvalidResult;
378378
}
379379

380-
auto new_params = DeclParams(name.name_loc_id, name.first_param_node_id,
381-
name.last_param_node_id, name.implicit_params_id,
382-
name.params_id);
380+
auto new_params = DeclParams(
381+
name.name_loc_id, name.first_param_node_id, name.last_param_node_id,
382+
name.implicit_param_patterns_id, name.param_patterns_id);
383383

384384
// Find the scope corresponding to the resolved instruction.
385385
// TODO: When diagnosing qualifiers on names, print a diagnostic that talks

toolchain/check/decl_name_stack.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,9 @@ class DeclNameStack {
102102
.last_param_node_id = name.last_param_node_id,
103103
.pattern_block_id = name.pattern_block_id,
104104
.implicit_param_refs_id = name.implicit_params_id,
105+
.implicit_param_patterns_id = name.implicit_param_patterns_id,
105106
.param_refs_id = name.params_id,
107+
.param_patterns_id = name.param_patterns_id,
106108
.is_extern = is_extern,
107109
.extern_library_id = extern_library,
108110
.non_owning_decl_id =

toolchain/check/deduce.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -282,8 +282,16 @@ auto DeductionContext::Deduce() -> bool {
282282
CARBON_KIND_SWITCH(param_inst) {
283283
// Deducing a symbolic binding from an argument with a constant value
284284
// deduces the binding as having that constant value.
285-
case CARBON_KIND(SemIR::BindSymbolicName bind): {
286-
auto& entity_name = context().entity_names().Get(bind.entity_name_id);
285+
case SemIR::InstKind::SymbolicBindingPattern:
286+
case SemIR::InstKind::BindSymbolicName: {
287+
auto entity_name_id = SemIR::EntityNameId::Invalid;
288+
if (auto bind = param_inst.TryAs<SemIR::SymbolicBindingPattern>()) {
289+
entity_name_id = bind->entity_name_id;
290+
} else {
291+
entity_name_id =
292+
param_inst.As<SemIR::BindSymbolicName>().entity_name_id;
293+
}
294+
auto& entity_name = context().entity_names().Get(entity_name_id);
287295
auto index = entity_name.bind_index;
288296
if (!index.is_valid() || index < first_deduced_index_) {
289297
break;

toolchain/check/eval.cpp

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1360,6 +1360,28 @@ static auto TryEvalInstInContext(EvalContext& eval_context,
13601360
case SemIR::ValueAsRef::Kind:
13611361
break;
13621362

1363+
case CARBON_KIND(SemIR::SymbolicBindingPattern bind): {
1364+
// TODO: disable constant evaluation of SymbolicBindingPattern once
1365+
// DeduceGenericCallArguments no longer needs implicit params to have
1366+
// constant values.
1367+
const auto& bind_name =
1368+
eval_context.entity_names().Get(bind.entity_name_id);
1369+
1370+
// If we know which specific we're evaluating within and this is an
1371+
// argument of that specific, its constant value is the corresponding
1372+
// argument value.
1373+
if (auto value =
1374+
eval_context.GetCompileTimeBindValue(bind_name.bind_index);
1375+
value.is_valid()) {
1376+
return value;
1377+
}
1378+
1379+
// The constant form of a symbolic binding is an idealized form of the
1380+
// original, with no equivalent value.
1381+
bind.entity_name_id =
1382+
eval_context.entity_names().MakeCanonical(bind.entity_name_id);
1383+
return MakeConstantResult(eval_context.context(), bind, Phase::Symbolic);
1384+
}
13631385
case CARBON_KIND(SemIR::BindSymbolicName bind): {
13641386
const auto& bind_name =
13651387
eval_context.entity_names().Get(bind.entity_name_id);
@@ -1394,6 +1416,9 @@ static auto TryEvalInstInContext(EvalContext& eval_context,
13941416
case CARBON_KIND(SemIR::NameRef typed_inst): {
13951417
return eval_context.GetConstantValue(typed_inst.value_id);
13961418
}
1419+
case CARBON_KIND(SemIR::ParamPattern param_pattern): {
1420+
return eval_context.GetConstantValue(param_pattern.subpattern_id);
1421+
}
13971422
case CARBON_KIND(SemIR::Converted typed_inst): {
13981423
return eval_context.GetConstantValue(typed_inst.result_id);
13991424
}
@@ -1466,7 +1491,6 @@ static auto TryEvalInstInContext(EvalContext& eval_context,
14661491
case SemIR::ReturnExpr::Kind:
14671492
case SemIR::Return::Kind:
14681493
case SemIR::StructLiteral::Kind:
1469-
case SemIR::SymbolicBindingPattern::Kind:
14701494
case SemIR::TupleLiteral::Kind:
14711495
case SemIR::VarStorage::Kind:
14721496
break;

toolchain/check/generic.cpp

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,12 @@ class RebuildGenericConstantInEvalBlockCallbacks final
108108
return true;
109109
}
110110

111+
if (auto pattern =
112+
context_.insts().TryGetAs<SemIR::SymbolicBindingPattern>(inst_id)) {
113+
inst_id = Rebuild(inst_id, *pattern);
114+
return true;
115+
}
116+
111117
return false;
112118
}
113119

@@ -426,39 +432,28 @@ auto ResolveSpecificDefinition(Context& context, SemIR::SpecificId specific_id)
426432
return true;
427433
}
428434

429-
// Replace the parameter with an invalid instruction so that we don't try
430-
// constructing a generic based on it. Note this is updating the param
431-
// refs block, not the actual params block, so will not be directly
432-
// reflected in SemIR output.
433-
static auto ReplaceInstructionWithError(Context& context,
434-
SemIR::InstId& inst_id) -> void {
435-
inst_id = context.AddInstInNoBlock<SemIR::Param>(
436-
context.insts().GetLocId(inst_id),
437-
{.type_id = SemIR::TypeId::Error,
438-
.name_id = SemIR::NameId::Base,
439-
.runtime_index = SemIR::RuntimeParamIndex::Invalid});
440-
}
441-
442-
auto RequireGenericParamsOnType(Context& context, SemIR::InstBlockId block_id)
443-
-> void {
444-
if (!block_id.is_valid() || block_id == SemIR::InstBlockId::Empty) {
435+
auto RequireGenericParamsOnType(Context& context,
436+
SemIR::InstBlockId pattern_block_id) -> void {
437+
if (!pattern_block_id.is_valid() ||
438+
pattern_block_id == SemIR::InstBlockId::Empty) {
445439
return;
446440
}
447-
for (auto& inst_id : context.inst_blocks().Get(block_id)) {
448-
auto param_info =
449-
SemIR::Function::GetParamFromParamRefId(context.sem_ir(), inst_id);
450-
if (param_info.GetNameId(context.sem_ir()) == SemIR::NameId::SelfValue) {
441+
for (auto& inst_id : context.inst_blocks().Get(pattern_block_id)) {
442+
auto name_id = SemIR::Function::GetParamPatternInfoFromPatternId(
443+
context.sem_ir(), inst_id)
444+
.GetNameId(context.sem_ir());
445+
if (name_id == SemIR::NameId::SelfValue) {
451446
CARBON_DIAGNOSTIC(SelfParameterNotAllowed, Error,
452447
"`self` parameter only allowed on functions");
453448
context.emitter().Emit(inst_id, SelfParameterNotAllowed);
454449

455-
ReplaceInstructionWithError(context, inst_id);
450+
inst_id = SemIR::InstId::BuiltinError;
456451
} else if (!context.constant_values().Get(inst_id).is_constant()) {
457452
CARBON_DIAGNOSTIC(GenericParamMustBeConstant, Error,
458453
"parameters of generic types must be constant");
459454
context.emitter().Emit(inst_id, GenericParamMustBeConstant);
460455

461-
ReplaceInstructionWithError(context, inst_id);
456+
inst_id = SemIR::InstId::BuiltinError;
462457
}
463458
}
464459
}
@@ -479,7 +474,7 @@ auto RequireGenericOrSelfImplicitFunctionParams(Context& context,
479474
"implicit parameters of functions must be constant or `self`");
480475
context.emitter().Emit(inst_id, ImplictParamMustBeConstant);
481476

482-
ReplaceInstructionWithError(context, inst_id);
477+
inst_id = SemIR::InstId::BuiltinError;
483478
}
484479
}
485480
}

0 commit comments

Comments
 (0)