Skip to content

Commit 53c99ab

Browse files
authored
Allow float division in compile-time constants (#1263)
1 parent e241250 commit 53c99ab

File tree

12 files changed

+119
-27
lines changed

12 files changed

+119
-27
lines changed

compiler/builders/any_builder.jou

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -305,14 +305,14 @@ class AnyBuilder:
305305
return result
306306

307307
@public
308-
def float_or_double(self, t: Type*, string: byte*) -> AnyBuilderValue:
308+
def float_or_double(self, t: Type*, dbl: double) -> AnyBuilderValue:
309309
result = AnyBuilderValue{}
310310
if self.lbuilder != NULL:
311-
result.lvalue = self.lbuilder.float_or_double(t, string)
311+
result.lvalue = self.lbuilder.float_or_double(t, dbl)
312312
if self.ubuilder != NULL:
313-
result.uvalue = self.ubuilder.float_or_double(t, string)
313+
result.uvalue = self.ubuilder.float_or_double(t, dbl)
314314
if self.hbuilder != NULL:
315-
result.hvalue = self.hbuilder.float_or_double(t, string)
315+
result.hvalue = self.hbuilder.float_or_double(t, dbl)
316316
return result
317317

318318
@public

compiler/builders/ast_to_builder.jou

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,7 @@ class AstToBuilder:
396396
case ConstantKind.Bool:
397397
return self.builder.boolean(constant.boolean)
398398
case ConstantKind.Float | ConstantKind.Double:
399-
return self.builder.float_or_double(constant.get_type(), constant.float_or_double_text)
399+
return self.builder.float_or_double(constant.get_type(), constant.float_or_double_value)
400400
case ConstantKind.Null:
401401
return self.builder.zero_of_type(void_ptr_type())
402402
case ConstantKind.EnumMember:

compiler/builders/hash_builder.jou

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -247,9 +247,9 @@ class HBuilder:
247247
return self.new_value(t)
248248

249249
@public
250-
def float_or_double(self, t: Type*, string: byte*) -> HBuilderValue:
250+
def float_or_double(self, t: Type*, dbl: double) -> HBuilderValue:
251251
self.hash.add_string("float_or_double")
252-
self.hash.add_string(string)
252+
self.hash.add_bytes(&dbl as byte*, sizeof(dbl))
253253
return self.new_value(t)
254254

255255
@public

compiler/builders/llvm_builder.jou

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -603,11 +603,11 @@ class LBuilder:
603603
}
604604

605605
@public
606-
def float_or_double(self, t: Type*, string: byte*) -> LBuilderValue:
606+
def float_or_double(self, t: Type*, dbl: double) -> LBuilderValue:
607607
assert t.kind == TypeKind.FloatingPoint
608608
return LBuilderValue{
609609
type = t,
610-
llvm_value = LLVMConstRealOfString(type_to_llvm(self.state, t), string)
610+
llvm_value = LLVMConstReal(type_to_llvm(self.state, t), dbl)
611611
}
612612

613613
@public

compiler/builders/uvg_builder.jou

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ class UBuilder:
156156
return ANONYMOUS_VALUE_ID
157157

158158
@public
159-
def float_or_double(self, t: Type*, string: byte*) -> int:
159+
def float_or_double(self, t: Type*, dbl: double) -> int:
160160
return ANONYMOUS_VALUE_ID
161161

162162
@public

compiler/constants.jou

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@ import "stdlib/mem.jou"
88
import "./types.jou"
99

1010

11+
# TODO: this function doesn't belong here
12+
declare atof(string: byte*) -> double
13+
14+
1115
def print_string(s: byte*) -> None:
1216
putchar('"')
1317
for i = 0; s[i] != '\0'; i++:
@@ -47,14 +51,28 @@ class EnumMemberConstant:
4751
enumtype: Type*
4852
memberidx: int
4953

54+
# Needed only for debugging
55+
def print_double(n: double) -> None:
56+
# 17 significant digits is enough to print any double, but it's too
57+
# much for some doubles. For example, with 17 significant digits, 3.14
58+
# becomes 3.1400000000000001.
59+
#
60+
# The same logic is in stdlib/json.jou at the time of writing this.
61+
buf: byte[64] = ""
62+
snprintf(buf, sizeof(buf), "%.16g", n)
63+
if atof(buf) != n:
64+
snprintf(buf, sizeof(buf), "%.17g", n)
65+
printf("%s", buf)
66+
67+
5068
@public
5169
class Constant:
5270
kind: ConstantKind
5371
union:
5472
integer: IntegerConstant
5573
pointer_string: byte*
5674
array_of_bytes: List[byte] # may contain zero bytes
57-
float_or_double_text: byte[100] # convenient because LLVM wants a string anyway
75+
float_or_double_value: double
5876
boolean: bool
5977
enum_member: EnumMemberConstant
6078
array_elements: List[Constant]
@@ -70,9 +88,13 @@ class Constant:
7088
else:
7189
printf("False\n")
7290
case ConstantKind.Float:
73-
printf("float %s\n", self.float_or_double_text)
91+
printf("float ")
92+
print_double(self.float_or_double_value)
93+
printf("\n")
7494
case ConstantKind.Double:
75-
printf("double %s\n", self.float_or_double_text)
95+
printf("double ")
96+
print_double(self.float_or_double_value)
97+
printf("\n")
7698
case ConstantKind.Integer:
7799
if self.integer.is_signed:
78100
signed_or_unsigned = "signed"

compiler/evaluate.jou

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -210,20 +210,36 @@ def evaluate_constant_expression(jou_file: JouFile*, expr: AstExpression*, resul
210210
result.integer.value *= -1
211211
return True
212212
case TypeKind.FloatingPoint:
213-
old_string: byte[100] = result.float_or_double_text
214-
if old_string[0] == '-':
215-
# Remove minus sign in beginning: -(-1.0) becomes 1.0
216-
strcpy(result.float_or_double_text, &old_string[1])
217-
return True
218-
elif strlen(old_string) + 1 < sizeof(result.float_or_double_text):
219-
# Add minus to beginning
220-
sprintf(result.float_or_double_text, "-%s", old_string)
221-
return True
222-
else:
223-
return False
213+
result.float_or_double_value *= -1
214+
return True
224215
case _:
225216
return False
226217

218+
# Evaluate float division in some very limited cases.
219+
# Mostly for defining "INF = 1.0 / 0.0" and similar constants.
220+
#
221+
# Note: This evaluates the division on the computer that is compiling,
222+
# not on the target. So if floating point works differently on the
223+
# target computer, this may be off. Currently that is not the case with
224+
# any of our supported compilation targets.
225+
case AstExpressionKind.Div:
226+
if type_hint != float_type() and type_hint != double_type():
227+
return False
228+
229+
lhs_rhs: Constant[2]
230+
for i = 0; i < 2; i++:
231+
if not evaluate_constant_expression(jou_file, &expr.operands[i], &lhs_rhs[i], type_hint):
232+
return False
233+
if lhs_rhs[i].get_type() != type_hint:
234+
return False
235+
236+
assert lhs_rhs[0].kind == lhs_rhs[1].kind
237+
*result = Constant{
238+
kind = lhs_rhs[0].kind,
239+
float_or_double_value = lhs_rhs[0].float_or_double_value / lhs_rhs[1].float_or_double_value,
240+
}
241+
return True
242+
227243
case AstExpressionKind.Array:
228244
if type_hint != NULL and type_hint.kind == TypeKind.Array:
229245
item_type_hint = type_hint.array.item_type

compiler/llvm.jou

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,8 @@ declare LLVMGetUndef(Ty: LLVMType*) -> LLVMValue*
355355
@public
356356
declare LLVMConstInt(IntTy: LLVMType*, N: int64, SignExtend: int) -> LLVMValue*
357357
@public
358+
declare LLVMConstReal(RealTy: LLVMType*, N: double) -> LLVMValue*
359+
@public
358360
declare LLVMConstRealOfString(RealTy: LLVMType*, Text: byte*) -> LLVMValue*
359361
@public
360362
declare LLVMConstStringInContext(C: LLVMContext*, Str: byte*, Length: uint32, DontNullTerminate: int) -> LLVMValue*

compiler/parser.jou

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ import "./constants.jou"
1515
import "./types.jou"
1616

1717

18+
# TODO: this function doesn't belong here
19+
declare atof(string: byte*) -> double
20+
21+
1822
# arity = number of operands, e.g. 2 for a binary operator such as "+"
1923
#
2024
# This cannot be used for ++ and --, because with them we can't know the kind from
@@ -685,10 +689,10 @@ class Parser:
685689
expr.constant = int_constant(uint_type(8), (self.tokens++).integer_value as int64)
686690
case TokenKind.Float:
687691
expr.kind = AstExpressionKind.Constant
688-
expr.constant = Constant{kind = ConstantKind.Float, float_or_double_text = (self.tokens++).short_string}
692+
expr.constant = Constant{kind = ConstantKind.Float, float_or_double_value = atof((self.tokens++).short_string)}
689693
case TokenKind.Double:
690694
expr.kind = AstExpressionKind.Constant
691-
expr.constant = Constant{kind = ConstantKind.Double, float_or_double_text = (self.tokens++).short_string}
695+
expr.constant = Constant{kind = ConstantKind.Double, float_or_double_value = atof((self.tokens++).short_string)}
692696
case TokenKind.String:
693697
expr.kind = AstExpressionKind.Constant
694698
expr.constant = Constant{kind = ConstantKind.PointerString, pointer_string = strdup((self.tokens++).long_string)}

tests/should_succeed/math_test.jou

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,18 @@
11
import "stdlib/math.jou"
22
import "stdlib/io.jou"
33

4+
# TODO: move these to stdlib/math.jou
5+
#
6+
# Here are some more readable ways to construct non-finite values. Note that
7+
# there are multiple different NaN values (see also: nan boxing), and the NaN
8+
# constructed here is just one of them.
9+
#
10+
# Like M_PI and M_E, these are doubles. Use "as float" as needed.
11+
@public
12+
const INFINITY: double = 1.0 / 0.0
13+
@public
14+
const NAN: double = 0.0 / 0.0
15+
416
def main() -> int:
517
printf("%f\n", M_PI) # Output: 3.141593
618
printf("%d\n", M_PI == acos(-1)) # Output: 1
@@ -18,23 +30,41 @@ def main() -> int:
1830
printf("%f\n", fmax(1.2, 3.4)) # Output: 3.400000
1931
printf("%f\n", fmaxf(1.2 as float, 3.4 as float)) # Output: 3.400000
2032

21-
test_doubles = [123.0, -123.0, 1.0 / 0.0, -1.0 / 0.0, 0.0 / 0.0, -0.0 / 0.0]
33+
test_doubles = [
34+
0.0, -0.0, 123.0, -123.0,
35+
# Parentheses are intentionally a bit weirdly here. It shouldn't matter
36+
# how they are and I want to test that.
37+
INFINITY, 1.0/0.0, -INFINITY, (-1.0)/0.0,
38+
NAN, 0.0/0.0, -NAN, 0.0/(-0.0),
39+
]
2240

2341
# Output: 1 0 0
2442
# Output: 1 0 0
43+
# Output: 1 0 0
44+
# Output: 1 0 0
45+
# Output: 0 1 0
2546
# Output: 0 1 0
2647
# Output: 0 1 0
48+
# Output: 0 1 0
49+
# Output: 0 0 1
50+
# Output: 0 0 1
2751
# Output: 0 0 1
2852
# Output: 0 0 1
2953
for d = &test_doubles[0]; d < &test_doubles[array_count(test_doubles)]; d++:
3054
printf("%d %d %d\n", isfinite(*d), isinf(*d), isnan(*d))
3155

3256
# Output: 1 0 0
3357
# Output: 1 0 0
58+
# Output: 1 0 0
59+
# Output: 1 0 0
60+
# Output: 0 1 0
61+
# Output: 0 1 0
3462
# Output: 0 1 0
3563
# Output: 0 1 0
3664
# Output: 0 0 1
3765
# Output: 0 0 1
66+
# Output: 0 0 1
67+
# Output: 0 0 1
3868
for d = &test_doubles[0]; d < &test_doubles[array_count(test_doubles)]; d++:
3969
printf("%d %d %d\n", isfinite(*d as float), isinf(*d as float), isnan(*d as float))
4070

@@ -106,4 +136,10 @@ def main() -> int:
106136
printf("%f\n", erf(2)) # Output: 0.995322
107137
printf("%f\n", erff(2)) # Output: 0.995322
108138

139+
# This wouldn't work if INFINITY is a float. Then the compiler would infer
140+
# that my_inf is a float and refuse to assign a double to it.
141+
my_inf = INFINITY
142+
my_inf = 12.34
143+
printf("%f\n", my_inf) # Output: 12.340000
144+
109145
return 0

0 commit comments

Comments
 (0)