diff --git a/core/prelude/types/string.carbon b/core/prelude/types/string.carbon
index 2f70a23722490..c6b5e9ecda946 100644
--- a/core/prelude/types/string.carbon
+++ b/core/prelude/types/string.carbon
@@ -8,6 +8,15 @@ import library "prelude/copy";
 import library "prelude/destroy";
 import library "prelude/types/char";
 import library "prelude/types/uint";
+import library "prelude/operators/index";
+import library "prelude/types/int";
+import library "prelude/operators/as";
+
+class String;
+
+// Forward declaration so the `IndexWith` impl in `String` can call the
+// `string.at` builtin, which is defined after the class.
+fn StringAt(s: String, index: i32) -> Char;

 class String {
   fn Size[self: Self]() -> u64 { return self.size; }
@@ -16,8 +25,19 @@ class String {
     fn Op[self: Self]() -> Self { return {.ptr = self.ptr, .size = self.size}; }
   }

+  // Indexing with any type that implicitly converts to `i32` yields the `Char`
+  // at that position.
+  impl forall [T:! ImplicitAs(i32)] as IndexWith(T) where .ElementType = Char {
+    fn At[self: Self](subscript: T) -> Char {
+      return StringAt(self, subscript);
+    }
+  }
+
   // TODO: This should be an array iterator.
   private var ptr: Char*;
   // TODO: This should be a word-sized integer.
   private var size: u64;
 }
+
+// Returns the character of `s` at `index`; implemented by the compiler.
+fn StringAt(s: String, index: i32) -> Char = "string.at";
diff --git a/toolchain/check/eval.cpp b/toolchain/check/eval.cpp
index 5dc529f51cf9f..f4d15e3927a43 100644
--- a/toolchain/check/eval.cpp
+++ b/toolchain/check/eval.cpp
@@ -1687,6 +1687,51 @@ static auto MakeConstantForBuiltinCall(EvalContext& eval_context,
       return context.constant_values().Get(arg_ids[0]);
     }

+    case SemIR::BuiltinFunctionKind::StringAt: {
+      // Folds `string.at` when the string and the index are both concrete
+      // constants.
+      Phase phase = Phase::Concrete;
+      auto str_id = GetConstantValue(eval_context, arg_ids[0], &phase);
+      auto index_id = GetConstantValue(eval_context, arg_ids[1], &phase);
+      if (phase != Phase::Concrete) {
+        return MakeNonConstantResult(phase);
+      }
+
+      // The string constant is a two-field struct whose first field is the
+      // string literal holding the characters.
+      auto str_struct = eval_context.insts().GetAs<SemIR::StructValue>(str_id);
+      auto elements = eval_context.inst_blocks().Get(str_struct.elements_id);
+      CARBON_CHECK(elements.size() == 2, "String struct should have 2 fields.");
+      auto ptr_const_id = eval_context.constant_values().Get(elements[0]);
+      auto string_literal = eval_context.insts().GetAs<SemIR::StringLiteral>(
+          eval_context.constant_values().GetInstId(ptr_const_id));
+      auto string_value = eval_context.sem_ir().string_literal_values().Get(
+          string_literal.string_literal_id);
+
+      auto index_inst = eval_context.insts().GetAs<SemIR::IntValue>(index_id);
+      const auto& index_val = eval_context.ints().Get(index_inst.int_id);
+
+      // Only fold in-range indices: reading the literal out of bounds would be
+      // undefined behavior in the compiler itself. The index expression
+      // handler reports the user-facing diagnostic for constant indices.
+      // NOTE(review): assumes `MakeNonConstantResult` leaves the call unfolded
+      // when given a concrete phase -- confirm.
+      if (index_val.isNegative() || index_val.getActiveBits() > 64 ||
+          index_val.getZExtValue() >= string_value.size()) {
+        return MakeNonConstantResult(phase);
+      }
+
+      // NOTE(review): the cast target is assumed to be `uint8_t`, so bytes
+      // >= 0x80 are not sign-extended -- confirm.
+      auto char_value =
+          static_cast<uint8_t>(string_value[index_val.getZExtValue()]);
+      auto int_id = eval_context.ints().Add(
+          llvm::APSInt(llvm::APInt(32, char_value), /*isUnsigned=*/false));
+      return MakeConstantResult(
+          eval_context.context(),
+          SemIR::IntValue{.type_id = call.type_id, .int_id = int_id}, phase);
+    }
+
     case SemIR::BuiltinFunctionKind::PrintChar:
     case SemIR::BuiltinFunctionKind::PrintInt:
     case SemIR::BuiltinFunctionKind::ReadChar:
diff --git a/toolchain/check/handle_index.cpp b/toolchain/check/handle_index.cpp
index 17a502403c6d5..60e23f9430136 100644
--- a/toolchain/check/handle_index.cpp
+++ b/toolchain/check/handle_index.cpp
@@ -26,6 +26,78 @@ auto HandleParseNode(Context& /*context*/, Parse::IndexExprStartId /*node_id*/)
   return true;
 }

+// Diagnoses a constant `index_int` that is not a valid index into the string
+// `operand_inst_id`. A negative index is always diagnosed; an index past the
+// end is diagnosed only when the operand is a constant backed by a string
+// literal, because otherwise the string's contents are not available here.
+// `index_inst_id` provides the diagnostic location and the index's type.
+static auto CheckStringIndexBounds(Context& context,
+                                   SemIR::InstId operand_inst_id,
+                                   SemIR::InstId index_inst_id,
+                                   const llvm::APInt& index_int) -> void {
+  if (index_int.isNegative()) {
+    CARBON_DIAGNOSTIC(ArrayIndexNegative, Error, "index `{0}` is negative.",
+                      TypedInt);
+    context.emitter().Emit(
+        SemIR::LocId(index_inst_id), ArrayIndexNegative,
+        {.type = context.insts().Get(index_inst_id).type_id(),
+         .value = index_int});
+    return;
+  }
+
+  auto operand_const_id = context.constant_values().Get(operand_inst_id);
+  if (!operand_const_id.is_constant()) {
+    return;
+  }
+
+  auto operand_const_inst_id =
+      context.constant_values().GetInstId(operand_const_id);
+  auto str_struct =
+      context.insts().TryGetAs<SemIR::StructValue>(operand_const_inst_id);
+  if (!str_struct) {
+    return;
+  }
+
+  auto elements = context.inst_blocks().Get(str_struct->elements_id);
+  CARBON_CHECK(elements.size() == 2, "String struct should have 2 fields.");
+
+  auto ptr_const_id = context.constant_values().Get(elements[0]);
+  auto ptr_inst_id = context.constant_values().GetInstId(ptr_const_id);
+  auto string_literal =
+      context.insts().TryGetAs<SemIR::StringLiteral>(ptr_inst_id);
+  if (!string_literal) {
+    return;
+  }
+
+  auto string_value = context.sem_ir().string_literal_values().Get(
+      string_literal->string_literal_id);
+  // `getZExtValue` requires the value to fit in 64 bits, so wider values are
+  // treated as out of range without converting them.
+  if (index_int.getActiveBits() > 64 ||
+      index_int.getZExtValue() >= string_value.size()) {
+    CARBON_DIAGNOSTIC(StringAtIndexOutOfBounds, Error,
+                      "string index `{0}` is past the end of the string.",
+                      TypedInt);
+    context.emitter().Emit(
+        SemIR::LocId(index_inst_id), StringAtIndexOutOfBounds,
+        {.type = context.insts().Get(index_inst_id).type_id(),
+         .value = index_int});
+  }
+}
+
+// Returns whether `class_type` is a class named `String`.
+//
+// NOTE(review): only the identifier is compared, so a class named `String`
+// declared outside `Core` also matches -- confirm whether the enclosing scope
+// should be checked too.
+static auto IsStringType(Context& context, SemIR::ClassType class_type)
+    -> bool {
+  auto& class_info = context.classes().Get(class_type.class_id);
+  auto identifier_id = class_info.name_id.AsIdentifierId();
+  return identifier_id.has_value() &&
+         context.identifiers().Get(identifier_id) == "String";
+}
+
 // Performs an index with base expression `operand_inst_id` and
 // `operand_type_id` for types that are not an array. This checks if
 // the base expression implements the `IndexWith` interface; if so, uses the
@@ -84,6 +156,29 @@ auto HandleParseNode(Context& context, Parse::IndexExprId node_id) -> bool {
       return true;
     }

+    case CARBON_KIND(SemIR::ClassType class_type): {
+      // For strings with a constant integer index, diagnose invalid indices
+      // up front; the index itself still goes through `IndexWith` below.
+      if (IsStringType(context, class_type)) {
+        auto index_const_id = context.constant_values().Get(index_inst_id);
+        if (index_const_id.is_constant()) {
+          auto index_const_inst_id =
+              context.constant_values().GetInstId(index_const_id);
+          if (auto index_val = context.insts().TryGetAs<SemIR::IntValue>(
+                  index_const_inst_id)) {
+            const auto& index_int = context.ints().Get(index_val->int_id);
+            CheckStringIndexBounds(context, operand_inst_id, index_inst_id,
+                                   index_int);
+          }
+        }
+      }
+
+      auto elem_id =
+          PerformIndexWith(context, node_id, operand_inst_id, index_inst_id);
+      context.node_stack().Push(node_id, elem_id);
+      return true;
+    }
+
     default: {
       auto elem_id =
           PerformIndexWith(context, node_id, operand_inst_id, index_inst_id);
diff --git a/toolchain/check/testdata/operators/overloaded/string_indexing.carbon b/toolchain/check/testdata/operators/overloaded/string_indexing.carbon
new file mode 100644
index 0000000000000..6815741430552
--- /dev/null
+++ b/toolchain/check/testdata/operators/overloaded/string_indexing.carbon
@@ -0,0 +1,26 @@
+// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// INCLUDE-FILE: toolchain/testing/testdata/min_prelude/full.carbon
+// AUTOUPDATE
+// TIP: To test this file alone, run:
+// TIP:   bazel test //toolchain/testing:file_test --test_arg=--file_tests=toolchain/check/testdata/operators/overloaded/string_indexing.carbon
+// TIP: To dump output, run:
+// TIP:   bazel run //toolchain/testing:file_test -- --dump_output --file_tests=toolchain/check/testdata/operators/overloaded/string_indexing.carbon
+
+
+// --- test_string_indexing.carbon
+
+
+import Core library "io";
+import Core library "range";
+
+fn PrintStr(msg: str) {
+  for (i: i32 in Core.Range(msg.Size() as i32)) {
+    Core.PrintChar(msg[i]);
+  }
+}
+
+fn Run() {
+  PrintStr("Hello World!\n");
+}
diff --git a/toolchain/check/testdata/operators/overloaded/string_indexing_negative.carbon b/toolchain/check/testdata/operators/overloaded/string_indexing_negative.carbon
new file mode 100644
index 0000000000000..2f3beee483e29
--- /dev/null
+++ b/toolchain/check/testdata/operators/overloaded/string_indexing_negative.carbon
@@ -0,0 +1,21 @@
+// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// INCLUDE-FILE: toolchain/testing/testdata/min_prelude/full.carbon
+// AUTOUPDATE
+// TIP: To test this file alone, run:
+// TIP:   bazel test //toolchain/testing:file_test --test_arg=--file_tests=toolchain/check/testdata/operators/overloaded/string_indexing_negative.carbon
+// TIP: To dump output, run:
+// TIP:   bazel run //toolchain/testing:file_test -- --dump_output --file_tests=toolchain/check/testdata/operators/overloaded/string_indexing_negative.carbon
+
+
+// --- fail_negative_index.carbon
+
+fn TestNegativeIndex() {
+  let test_str: str = "Test";
+  // CHECK:STDERR: fail_negative_index.carbon:[[@LINE+4]]:31: error: index `-1` is negative. [ArrayIndexNegative]
+  // CHECK:STDERR:   let c: Core.Char = test_str[-1];
+  // CHECK:STDERR:                               ^~
+  // CHECK:STDERR:
+  let c: Core.Char = test_str[-1];
+}
diff --git a/toolchain/check/testdata/operators/overloaded/string_indexing_out_of_bounds.carbon b/toolchain/check/testdata/operators/overloaded/string_indexing_out_of_bounds.carbon
new file mode 100644
index 0000000000000..22c3d0f1d22cd
--- /dev/null
+++ b/toolchain/check/testdata/operators/overloaded/string_indexing_out_of_bounds.carbon
@@ -0,0 +1,17 @@
+// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// INCLUDE-FILE: toolchain/testing/testdata/min_prelude/full.carbon
+// AUTOUPDATE
+// TIP: To test this file alone, run:
+// TIP:   bazel test //toolchain/testing:file_test --test_arg=--file_tests=toolchain/check/testdata/operators/overloaded/string_indexing_out_of_bounds.carbon
+// TIP: To dump output, run:
+// TIP:   bazel run //toolchain/testing:file_test -- --dump_output --file_tests=toolchain/check/testdata/operators/overloaded/string_indexing_out_of_bounds.carbon
+
+// --- fail_out_of_bounds.carbon
+
+// CHECK:STDERR: fail_out_of_bounds.carbon:[[@LINE+4]]:27: error: string index `4` is past the end of the string. [StringAtIndexOutOfBounds]
+// CHECK:STDERR: var c: Core.Char = "Test"[4];
+// CHECK:STDERR:                           ^
+// CHECK:STDERR:
+var c: Core.Char = "Test"[4];
diff --git a/toolchain/check/testdata/operators/overloaded/string_indexing_wrong_type.carbon b/toolchain/check/testdata/operators/overloaded/string_indexing_wrong_type.carbon
new file mode 100644
index 0000000000000..5b86675cc1bfd
--- /dev/null
+++ b/toolchain/check/testdata/operators/overloaded/string_indexing_wrong_type.carbon
@@ -0,0 +1,20 @@
+// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// INCLUDE-FILE: toolchain/testing/testdata/min_prelude/full.carbon
+// AUTOUPDATE
+// TIP: To test this file alone, run:
+// TIP:   bazel test //toolchain/testing:file_test --test_arg=--file_tests=toolchain/check/testdata/operators/overloaded/string_indexing_wrong_type.carbon
+// TIP: To dump output, run:
+// TIP:   bazel run //toolchain/testing:file_test -- --dump_output --file_tests=toolchain/check/testdata/operators/overloaded/string_indexing_wrong_type.carbon
+
+// --- fail_wrong_type.carbon
+
+fn TestWrongType() {
+  var x: i32 = 42;
+  // CHECK:STDERR: fail_wrong_type.carbon:[[@LINE+4]]:22: error: cannot access member of interface `Core.IndexWith(Core.IntLiteral)` in type `i32` that does not implement that interface [MissingImplInMemberAccess]
+  // CHECK:STDERR:   var c: Core.Char = x[0];
+  // CHECK:STDERR:                      ^~~~
+  // CHECK:STDERR:
+  var c: Core.Char = x[0];
+}
diff --git a/toolchain/diagnostics/diagnostic_kind.def b/toolchain/diagnostics/diagnostic_kind.def
index bcac19f85eeac..627a1e9c056fa 100644
--- a/toolchain/diagnostics/diagnostic_kind.def
+++ b/toolchain/diagnostics/diagnostic_kind.def
@@ -397,6 +397,7 @@ CARBON_DIAGNOSTIC_KIND(AddrOnNonSelfParam)
 CARBON_DIAGNOSTIC_KIND(AddrOnNonPointerType)
 CARBON_DIAGNOSTIC_KIND(ArrayBoundTooLarge)
 CARBON_DIAGNOSTIC_KIND(ArrayBoundNegative)
+CARBON_DIAGNOSTIC_KIND(ArrayIndexNegative)
 CARBON_DIAGNOSTIC_KIND(ArrayIndexOutOfBounds)
 CARBON_DIAGNOSTIC_KIND(ArrayInitFromLiteralArgCountMismatch)
 CARBON_DIAGNOSTIC_KIND(ArrayInitFromExprArgCountMismatch)
@@ -450,6 +451,7 @@ CARBON_DIAGNOSTIC_KIND(NegativeIntInUnsignedType)
 CARBON_DIAGNOSTIC_KIND(NonConstantCallToCompTimeOnlyFunction)
 CARBON_DIAGNOSTIC_KIND(CompTimeOnlyFunctionHere)
 CARBON_DIAGNOSTIC_KIND(SelfOutsideImplicitParamList)
+CARBON_DIAGNOSTIC_KIND(StringAtIndexOutOfBounds)
 CARBON_DIAGNOSTIC_KIND(StringLiteralTooLong)
 CARBON_DIAGNOSTIC_KIND(StringLiteralTypeIncomplete)
 CARBON_DIAGNOSTIC_KIND(StringLiteralTypeUnexpected)
diff --git a/toolchain/lower/handle_call.cpp b/toolchain/lower/handle_call.cpp
index 611c5d9723089..47a2123966d90 100644
--- a/toolchain/lower/handle_call.cpp
+++ b/toolchain/lower/handle_call.cpp
@@ -318,6 +318,35 @@ static auto HandleBuiltinCall(FunctionContext& context, SemIR::InstId inst_id,
       return;
     }

+    case SemIR::BuiltinFunctionKind::StringAt: {
+      // Lowers `string.at`: loads the string value, takes its first field as
+      // the character pointer, and reads the byte at the requested index.
+      // NOTE(review): no bounds check is emitted here -- confirm that
+      // out-of-range indices are rejected before lowering.
+      auto string_inst_id = arg_ids[0];
+      auto* string_arg = context.GetValue(string_inst_id);
+      auto string_type_id = context.GetTypeIdOfInst(string_inst_id);
+      auto* string_type = context.GetType(string_type_id);
+      auto* string_value =
+          context.builder().CreateLoad(string_type, string_arg, "string.load");
+      auto* string_ptr_field =
+          context.builder().CreateExtractValue(string_value, {0}, "string.ptr");
+
+      // The index is applied in units of one byte.
+      auto* i8_type = llvm::Type::getInt8Ty(context.llvm_context());
+      auto* index_value = context.GetValue(arg_ids[1]);
+      auto* char_ptr = context.builder().CreateInBoundsGEP(
+          i8_type, string_ptr_field, index_value, "string.char_ptr");
+      auto* char_i8 =
+          context.builder().CreateLoad(i8_type, char_ptr, "string.char");
+
+      // Zero-extend the byte to the type of the call instruction.
+      context.SetLocal(inst_id, context.builder().CreateZExt(
+                                    char_i8, context.GetTypeOfInst(inst_id),
+                                    "string.char.zext"));
+      return;
+    }
+
     case SemIR::BuiltinFunctionKind::TypeAnd: {
       context.SetLocal(inst_id, context.GetTypeAsValue());
       return;
diff --git a/toolchain/sem_ir/builtin_function_kind.cpp b/toolchain/sem_ir/builtin_function_kind.cpp
index c6af7ca0a9d8c..c9b5d00685bae 100644
--- a/toolchain/sem_ir/builtin_function_kind.cpp
+++ b/toolchain/sem_ir/builtin_function_kind.cpp
@@ -177,6 +177,26 @@ struct AnyType {
     return true;
   }
 };

+// Constraint that checks if a type is Core.String.
+//
+// NOTE(review): only the class name is compared, so any class named `String`
+// satisfies this constraint regardless of its package -- confirm whether the
+// enclosing scope should also be checked.
+struct CoreStringType {
+  static auto Check(const File& sem_ir, ValidateState& /*state*/,
+                    TypeId type_id) -> bool {
+    auto type_inst_id = sem_ir.types().GetInstId(type_id);
+    auto class_type = sem_ir.insts().TryGetAs<ClassType>(type_inst_id);
+    if (!class_type) {
+      // Not a class type, so not a string.
+      return false;
+    }
+
+    const auto& class_info = sem_ir.classes().Get(class_type->class_id);
+    return sem_ir.names().GetFormatted(class_info.name_id).str() == "String";
+  }
+};
+
 // Constraint that requires the type to be the type type.
 using Type = BuiltinType;
@@ -322,6 +342,13 @@ constexpr BuiltinInfo PrintInt = {
 constexpr BuiltinInfo ReadChar = {"read.char",
                                   ValidateSignatureAnySizedInt>};

+// Gets a character from a string at the given index.
+// NOTE(review): the parameter constraints were reconstructed from the
+// `fn StringAt(s: String, index: i32)` declaration -- confirm.
+constexpr BuiltinInfo StringAt = {
+    "string.at",
+    ValidateSignature<auto(CoreStringType, AnySizedInt)->AnySizedInt>};
+
 // Returns the `Core.CharLiteral` type.
 constexpr BuiltinInfo CharLiteralMakeType = {"char_literal.make_type",
                                              ValidateSignatureType>};
diff --git a/toolchain/sem_ir/builtin_function_kind.def b/toolchain/sem_ir/builtin_function_kind.def
index d0b1306fffd5a..e5d3e74f179bf 100644
--- a/toolchain/sem_ir/builtin_function_kind.def
+++ b/toolchain/sem_ir/builtin_function_kind.def
@@ -30,6 +30,7 @@ CARBON_SEM_IR_BUILTIN_FUNCTION_KIND(PrimitiveCopy)
 CARBON_SEM_IR_BUILTIN_FUNCTION_KIND(PrintChar)
 CARBON_SEM_IR_BUILTIN_FUNCTION_KIND(PrintInt)
 CARBON_SEM_IR_BUILTIN_FUNCTION_KIND(ReadChar)
+CARBON_SEM_IR_BUILTIN_FUNCTION_KIND(StringAt)

 // Type factories.
 CARBON_SEM_IR_BUILTIN_FUNCTION_KIND(CharLiteralMakeType)