Skip to content

Commit 5b72dbf

Browse files
martinsander00 authored and hantangwangd committed
Fix precision loss in parse_duration for large millisecond values
Fixes #25340
1 parent 063d3c1 commit 5b72dbf

File tree

3 files changed

+152
-2
lines changed

3 files changed

+152
-2
lines changed

presto-main-base/src/main/java/com/facebook/presto/operator/scalar/DateTimeFunctions.java

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,11 @@
3838
import org.joda.time.format.DateTimeFormatterBuilder;
3939
import org.joda.time.format.ISODateTimeFormat;
4040

41+
import java.math.BigDecimal;
4142
import java.util.Locale;
4243
import java.util.concurrent.TimeUnit;
44+
import java.util.regex.Matcher;
45+
import java.util.regex.Pattern;
4346

4447
import static com.facebook.presto.common.type.DateTimeEncoding.packDateTimeWithZone;
4548
import static com.facebook.presto.common.type.DateTimeEncoding.unpackMillisUtc;
@@ -83,6 +86,7 @@ public final class DateTimeFunctions
8386
private static final DateTimeField MONTH_OF_YEAR = UTC_CHRONOLOGY.monthOfYear();
8487
private static final DateTimeField QUARTER = QUARTER_OF_YEAR.getField(UTC_CHRONOLOGY);
8588
private static final DateTimeField YEAR = UTC_CHRONOLOGY.year();
89+
// Shape of a duration literal as matched by parseDuration: optional leading whitespace, an unsigned
// decimal number (group 1: digits with an optional ".digits" fraction), optional whitespace, an
// alphabetic unit (group 2), and optional trailing whitespace. No sign or exponent is permitted.
private static final Pattern PATTERN = Pattern.compile("^\\s*(\\d+(?:\\.\\d+)?)\\s*([a-zA-Z]+)\\s*$");
8690
private static final int MILLISECONDS_IN_SECOND = 1000;
8791
private static final int MILLISECONDS_IN_MINUTE = 60 * MILLISECONDS_IN_SECOND;
8892
private static final int MILLISECONDS_IN_HOUR = 60 * MILLISECONDS_IN_MINUTE;
@@ -1437,14 +1441,53 @@ else if (character == '%') {
14371441
@SqlType(StandardTypes.INTERVAL_DAY_TO_SECOND)
14381442
public static long parseDuration(@SqlType("varchar(x)") Slice duration)
14391443
{
1444+
String durationStr = duration.toStringUtf8();
1445+
1446+
if (durationStr.isEmpty()) {
1447+
throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "duration is empty");
1448+
}
1449+
14401450
try {
1441-
return Duration.valueOf(duration.toStringUtf8()).toMillis();
1451+
Matcher matcher = PATTERN.matcher(durationStr);
1452+
1453+
if (!matcher.matches()) {
1454+
throw new PrestoException(INVALID_FUNCTION_ARGUMENT,
1455+
"duration is not a valid data duration string: " + durationStr);
1456+
}
1457+
1458+
BigDecimal value = new BigDecimal(matcher.group(1));
1459+
TimeUnit timeUnit = Duration.valueOfTimeUnit(matcher.group(2));
1460+
1461+
return value.multiply(millisPerTimeUnit(timeUnit))
1462+
.add(BigDecimal.valueOf(0.5)).longValue();
14421463
}
1443-
catch (IllegalArgumentException e) {
1464+
catch (IllegalArgumentException | ArithmeticException e) {
14441465
throw new PrestoException(INVALID_FUNCTION_ARGUMENT, e);
14451466
}
14461467
}
14471468

1469+
private static BigDecimal millisPerTimeUnit(TimeUnit timeUnit)
1470+
{
1471+
switch (timeUnit) {
1472+
case NANOSECONDS:
1473+
return new BigDecimal("0.000001");
1474+
case MICROSECONDS:
1475+
return new BigDecimal("0.001");
1476+
case MILLISECONDS:
1477+
return BigDecimal.ONE;
1478+
case SECONDS:
1479+
return BigDecimal.valueOf(1000);
1480+
case MINUTES:
1481+
return BigDecimal.valueOf(60_000);
1482+
case HOURS:
1483+
return BigDecimal.valueOf(3_600_000);
1484+
case DAYS:
1485+
return BigDecimal.valueOf(86_400_000);
1486+
default:
1487+
throw new AssertionError("Unknown TimeUnit: " + timeUnit);
1488+
}
1489+
}
1490+
14481491
private static long timeAtTimeZone(SqlFunctionProperties properties, long timeWithTimeZone, TimeZoneKey timeZoneKey)
14491492
{
14501493
DateTimeZone sourceTimeZone = getDateTimeZone(unpackZoneKey(timeWithTimeZone));
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package com.facebook.presto.operator.scalar;
15+
16+
import io.airlift.slice.Slice;
17+
import io.airlift.slice.Slices;
18+
import io.airlift.units.Duration;
19+
import org.openjdk.jmh.annotations.Benchmark;
20+
import org.openjdk.jmh.annotations.BenchmarkMode;
21+
import org.openjdk.jmh.annotations.Fork;
22+
import org.openjdk.jmh.annotations.Measurement;
23+
import org.openjdk.jmh.annotations.Mode;
24+
import org.openjdk.jmh.annotations.OutputTimeUnit;
25+
import org.openjdk.jmh.annotations.Param;
26+
import org.openjdk.jmh.annotations.Scope;
27+
import org.openjdk.jmh.annotations.Setup;
28+
import org.openjdk.jmh.annotations.State;
29+
import org.openjdk.jmh.annotations.Warmup;
30+
import org.openjdk.jmh.infra.Blackhole;
31+
import org.openjdk.jmh.runner.Runner;
32+
import org.openjdk.jmh.runner.RunnerException;
33+
import org.openjdk.jmh.runner.options.Options;
34+
import org.openjdk.jmh.runner.options.OptionsBuilder;
35+
36+
import java.util.Random;
37+
import java.util.concurrent.TimeUnit;
38+
39+
@State(Scope.Benchmark)
40+
@OutputTimeUnit(TimeUnit.NANOSECONDS)
41+
@Fork(value = 1, jvmArgs = {"-Xms2G", "-Xmx2G"})
42+
@Warmup(iterations = 5, time = 1)
43+
@Measurement(iterations = 10, time = 1)
44+
@BenchmarkMode(Mode.AverageTime)
45+
public class BenchmarkDateTimeFunctions
46+
{
47+
@Param({"ns", "us", "ms", "s", "m", "h", "d"})
48+
private String unit = "ns";
49+
50+
private Random random = new Random();
51+
52+
@Setup
53+
public void setup()
54+
{
55+
random = new Random(42); // Fixed seed for reproducibility
56+
}
57+
58+
@Benchmark
59+
public void testBaseline(Blackhole bh)
60+
{
61+
int v1 = random.nextInt(10000);
62+
int v2 = random.nextInt(10000);
63+
Slice value = Slices.utf8Slice(v1 + "." + v2 + " " + unit);
64+
bh.consume(value.toStringUtf8());
65+
}
66+
67+
@Benchmark
68+
public void testUseBigDecimal(Blackhole bh)
69+
{
70+
int v1 = random.nextInt(10000);
71+
int v2 = random.nextInt(10000);
72+
Slice value = Slices.utf8Slice(v1 + "." + v2 + " " + unit);
73+
bh.consume(DateTimeFunctions.parseDuration(value));
74+
}
75+
76+
@Benchmark
77+
public void testUseDouble(Blackhole bh)
78+
{
79+
int v1 = random.nextInt(10000);
80+
int v2 = random.nextInt(10000);
81+
Slice value = Slices.utf8Slice(v1 + "." + v2 + " " + unit);
82+
bh.consume(Duration.valueOf(value.toStringUtf8()).toMillis());
83+
}
84+
85+
public static void main(String[] args)
86+
throws RunnerException
87+
{
88+
Options opt = new OptionsBuilder()
89+
.include(BenchmarkDateTimeFunctions.class.getSimpleName())
90+
.build();
91+
92+
new Runner(opt).run();
93+
}
94+
}

presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestDateTimeFunctionsBase.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1226,10 +1226,23 @@ public void testParseDuration()
12261226
assertFunction("parse_duration('1234.567h')", INTERVAL_DAY_TIME, new SqlIntervalDayTime(51, 10, 34, 1, 200));
12271227
assertFunction("parse_duration('1234.567d')", INTERVAL_DAY_TIME, new SqlIntervalDayTime(1234, 13, 36, 28, 800));
12281228

1229+
// trailing spaces
1230+
assertFunction("parse_duration('1234 ns ')", INTERVAL_DAY_TIME, new SqlIntervalDayTime(0, 0, 0, 0, 0));
1231+
assertFunction("parse_duration('1234 us ')", INTERVAL_DAY_TIME, new SqlIntervalDayTime(0, 0, 0, 0, 1));
1232+
assertFunction("parse_duration('1234ms ')", INTERVAL_DAY_TIME, new SqlIntervalDayTime(0, 0, 0, 1, 234));
1233+
12291234
// invalid function calls
12301235
assertInvalidFunction("parse_duration('')", "duration is empty");
12311236
assertInvalidFunction("parse_duration('1f')", "Unknown time unit: f");
12321237
assertInvalidFunction("parse_duration('abc')", "duration is not a valid data duration string: abc");
1238+
1239+
// long milliseconds edge cases
1240+
assertFunction("parse_duration('7702741401940153ms')", INTERVAL_DAY_TIME, new SqlIntervalDayTime(89152099, 13, 25, 40, 153));
1241+
assertFunction("parse_duration('9117756383778565ms')", INTERVAL_DAY_TIME, new SqlIntervalDayTime(105529587, 18, 36, 18, 565));
1242+
1243+
// Test precision for large values with fractional seconds
1244+
assertFunction("parse_duration('7702741401940.153s')", INTERVAL_DAY_TIME, new SqlIntervalDayTime(89152099, 13, 25, 40, 153));
1245+
assertFunction("parse_duration('7702741401940.153 s')", INTERVAL_DAY_TIME, new SqlIntervalDayTime(89152099, 13, 25, 40, 153));
12331246
}
12341247

12351248
@Test

0 commit comments

Comments
 (0)