Skip to content

Commit ce911a5

Browse files
committed
feat(graph, store): extend supported types for first and last aggregations
1 parent 2418aa1 commit ce911a5

File tree

8 files changed

+365
-34
lines changed

8 files changed

+365
-34
lines changed

graph/src/schema/input/mod.rs

Lines changed: 88 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -788,7 +788,7 @@ pub struct Aggregate {
788788
}
789789

790790
impl Aggregate {
791-
fn new(_schema: &Schema, name: &str, field_type: &s::Type, dir: &s::Directive) -> Self {
791+
fn new(schema: &Schema, name: &str, field_type: &s::Type, dir: &s::Directive) -> Self {
792792
let func = dir
793793
.argument("fn")
794794
.unwrap()
@@ -818,7 +818,7 @@ impl Aggregate {
818818
arg,
819819
cumulative,
820820
field_type: field_type.clone(),
821-
value_type: field_type.get_base_type().parse().unwrap(),
821+
value_type: Field::scalar_value_type(schema, field_type),
822822
}
823823
}
824824

@@ -2366,27 +2366,63 @@ mod validations {
23662366
}
23672367
}
23682368

2369-
fn aggregate_fields_are_numbers(agg_type: &s::ObjectType, errors: &mut Vec<Err>) {
2369+
fn aggregate_field_types(
2370+
schema: &Schema,
2371+
agg_type: &s::ObjectType,
2372+
errors: &mut Vec<Err>,
2373+
) {
2374+
fn is_first_last(agg_directive: &s::Directive) -> bool {
2375+
match agg_directive.argument(kw::FUNC) {
2376+
Some(s::Value::Enum(func) | s::Value::String(func)) => {
2377+
func == AggregateFn::First.as_str()
2378+
|| func == AggregateFn::Last.as_str()
2379+
}
2380+
_ => false,
2381+
}
2382+
}
2383+
23702384
let errs = agg_type
23712385
.fields
23722386
.iter()
2373-
.filter(|field| field.find_directive(kw::AGGREGATE).is_some())
2374-
.map(|field| match field.field_type.value_type() {
2375-
Ok(vt) => {
2376-
if vt.is_numeric() {
2377-
Ok(())
2378-
} else {
2379-
Err(Err::NonNumericAggregate(
2387+
.filter_map(|field| {
2388+
field
2389+
.find_directive(kw::AGGREGATE)
2390+
.map(|agg_directive| (field, agg_directive))
2391+
})
2392+
.map(|(field, agg_directive)| {
2393+
let is_first_last = is_first_last(agg_directive);
2394+
2395+
match field.field_type.value_type() {
2396+
Ok(value_type) if value_type.is_numeric() => Ok(()),
2397+
Ok(ValueType::Bytes | ValueType::String) if is_first_last => Ok(()),
2398+
Ok(_) if is_first_last => Err(Err::InvalidFirstLastAggregate(
2399+
agg_type.name.clone(),
2400+
field.name.clone(),
2401+
)),
2402+
Ok(_) => Err(Err::NonNumericAggregate(
2403+
agg_type.name.to_owned(),
2404+
field.name.to_owned(),
2405+
)),
2406+
Err(_) => {
2407+
if is_first_last
2408+
&& schema
2409+
.entity_types
2410+
.iter()
2411+
.find(|entity_type| {
2412+
entity_type.name.eq(field.field_type.get_base_type())
2413+
})
2414+
.is_some()
2415+
{
2416+
return Ok(());
2417+
}
2418+
2419+
Err(Err::FieldTypeUnknown(
23802420
agg_type.name.to_owned(),
23812421
field.name.to_owned(),
2422+
field.field_type.get_base_type().to_owned(),
23822423
))
23832424
}
23842425
}
2385-
Err(_) => Err(Err::FieldTypeUnknown(
2386-
agg_type.name.to_owned(),
2387-
field.name.to_owned(),
2388-
field.field_type.get_base_type().to_owned(),
2389-
)),
23902426
})
23912427
.filter_map(|err| err.err());
23922428
errors.extend(errs);
@@ -2519,16 +2555,10 @@ mod validations {
25192555
continue;
25202556
}
25212557
};
2522-
let field_type = match field.field_type.value_type() {
2523-
Ok(field_type) => field_type,
2524-
Err(_) => {
2525-
errors.push(Err::NonNumericAggregate(
2526-
agg_type.name.to_owned(),
2527-
field.name.to_owned(),
2528-
));
2529-
continue;
2530-
}
2531-
};
2558+
2559+
let is_first_last =
2560+
matches!(func, AggregateFn::First | AggregateFn::Last);
2561+
25322562
// It would be nicer to use a proper struct here
25332563
// and have that implement
25342564
// `sqlexpr::ExprVisitor` but we need access to
@@ -2539,6 +2569,18 @@ mod validations {
25392569
let arg_type = match source.field(ident) {
25402570
Some(arg_field) => match arg_field.field_type.value_type() {
25412571
Ok(arg_type) if arg_type.is_numeric() => arg_type,
2572+
Ok(ValueType::Bytes | ValueType::String)
2573+
if is_first_last =>
2574+
{
2575+
return Ok(());
2576+
}
2577+
Err(_)
2578+
if is_first_last
2579+
&& arg_field.field_type.get_base_type()
2580+
== field.field_type.get_base_type() =>
2581+
{
2582+
return Ok(());
2583+
}
25422584
Ok(_) | Err(_) => {
25432585
return Err(Err::AggregationNonNumericArg(
25442586
agg_type.name.to_owned(),
@@ -2556,15 +2598,27 @@ mod validations {
25562598
));
25572599
}
25582600
};
2559-
if arg_type > field_type {
2560-
return Err(Err::AggregationNonMatchingArg(
2561-
agg_type.name.to_owned(),
2562-
field.name.to_owned(),
2563-
arg.to_owned(),
2564-
arg_type.to_str().to_owned(),
2565-
field_type.to_str().to_owned(),
2566-
));
2601+
2602+
match field.field_type.value_type() {
2603+
Ok(field_type) if field_type.is_numeric() => {
2604+
if arg_type > field_type {
2605+
return Err(Err::AggregationNonMatchingArg(
2606+
agg_type.name.to_owned(),
2607+
field.name.to_owned(),
2608+
arg.to_owned(),
2609+
arg_type.to_str().to_owned(),
2610+
field_type.to_str().to_owned(),
2611+
));
2612+
}
2613+
}
2614+
Ok(_) | Err(_) => {
2615+
return Err(Err::NonNumericAggregate(
2616+
agg_type.name.to_owned(),
2617+
field.name.to_owned(),
2618+
));
2619+
}
25672620
}
2621+
25682622
Ok(())
25692623
};
25702624
if let Err(mut errs) = sqlexpr::parse(arg, check_ident) {
@@ -2661,7 +2715,7 @@ mod validations {
26612715
errors.push(err);
26622716
}
26632717
no_derived_fields(agg_type, &mut errors);
2664-
aggregate_fields_are_numbers(agg_type, &mut errors);
2718+
aggregate_field_types(self, agg_type, &mut errors);
26652719
aggregate_directive(self, agg_type, &mut errors);
26662720
// check timeseries directive has intervals and args
26672721
aggregation_intervals(agg_type, &mut errors);

graph/src/schema/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,8 @@ pub enum SchemaValidationError {
123123
TimestampFieldMissing(String),
124124
#[error("Aggregation {0}, field{1}: aggregates must use a numeric type, one of Int, Int8, BigInt, and BigDecimal")]
125125
NonNumericAggregate(String, String),
126+
#[error("Aggregation '{0}', field '{1}': first/last aggregates must use a numeric, byte array, string or a reference type")]
127+
InvalidFirstLastAggregate(String, String),
126128
#[error("Aggregation {0} is missing the `source` argument")]
127129
AggregationMissingSource(String),
128130
#[error(
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# fail: FieldTypeUnknown("Stats", "firstBlockNumber", "BlockNumber")
2+
3+
type Data @entity(timeseries: true) {
4+
id: Int8!
5+
timestamp: Timestamp!
6+
blockNumber: Int8!
7+
}
8+
9+
type Stats @aggregation(intervals: ["hour", "day"], source: "Data") {
10+
id: Int8!
11+
timestamp: Timestamp!
12+
13+
firstBlockNumber: BlockNumber! @aggregate(fn: "first", arg: "blockNumber")
14+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# fail: FieldTypeUnknown("Stats", "lastBlockNumber", "BlockNumber")
2+
3+
type Data @entity(timeseries: true) {
4+
id: Int8!
5+
timestamp: Timestamp!
6+
blockNumber: Int8!
7+
}
8+
9+
type Stats @aggregation(intervals: ["hour", "day"], source: "Data") {
10+
id: Int8!
11+
timestamp: Timestamp!
12+
13+
lastBlockNumber: BlockNumber! @aggregate(fn: "last", arg: "blockNumber")
14+
}
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# valid: Non-numeric first and last aggregations
2+
3+
type EntityA @entity(immutable: true) {
4+
id: ID!
5+
}
6+
7+
type EntityB @entity(immutable: true) {
8+
id: Int8!
9+
}
10+
11+
type EntityC @entity(immutable: true) {
12+
id: Bytes!
13+
}
14+
15+
type Data @entity(timeseries: true) {
16+
id: Int8!
17+
timestamp: Timestamp!
18+
fieldA: EntityA!
19+
fieldB: EntityB!
20+
fieldC: EntityC!
21+
fieldD: String!
22+
fieldE: Bytes!
23+
}
24+
25+
type Stats @aggregation(intervals: ["hour", "day"], source: "Data") {
26+
id: Int8!
27+
timestamp: Timestamp!
28+
29+
firstA: EntityA! @aggregate(fn: "first", arg: "fieldA")
30+
lastA: EntityA! @aggregate(fn: "last", arg: "fieldA")
31+
32+
firstB: EntityB! @aggregate(fn: "first", arg: "fieldB")
33+
lastB: EntityB! @aggregate(fn: "last", arg: "fieldB")
34+
35+
firstC: EntityC! @aggregate(fn: "first", arg: "fieldC")
36+
lastC: EntityC! @aggregate(fn: "last", arg: "fieldC")
37+
38+
firstD: String! @aggregate(fn: "first", arg: "fieldD")
39+
lastD: String! @aggregate(fn: "last", arg: "fieldD")
40+
41+
firstE: Bytes! @aggregate(fn: "first", arg: "fieldE")
42+
lastE: Bytes! @aggregate(fn: "last", arg: "fieldE")
43+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
-- This file was generated by generate.sh in this directory
2+
set search_path = public;
3+
drop aggregate arg_min_text(text_and_value);
4+
drop aggregate arg_max_text(text_and_value);
5+
drop function arg_from_text_and_value(text_and_value);
6+
drop function arg_max_agg_text(text_and_value, text_and_value);
7+
drop function arg_min_agg_text(text_and_value, text_and_value);
8+
drop type text_and_value;
9+
drop aggregate arg_min_bytea(bytea_and_value);
10+
drop aggregate arg_max_bytea(bytea_and_value);
11+
drop function arg_from_bytea_and_value(bytea_and_value);
12+
drop function arg_max_agg_bytea(bytea_and_value, bytea_and_value);
13+
drop function arg_min_agg_bytea(bytea_and_value, bytea_and_value);
14+
drop type bytea_and_value;
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
#! /bin/bash
2+
3+
# Generate up and down migrations to define arg_min and arg_max functions
4+
# for the types listed in `types`.
5+
#
6+
# The functions can all be used like
7+
#
8+
# select first_int4((arg, value)) from t
9+
#
10+
# and return the `arg int4` for the smallest value `value int8`. If there
11+
# are several rows with the smallest value, we try hard to return the first
12+
# one, but that also depends on how Postgres calculates these
13+
# aggregations. Note that the relation over which we are aggregating does
14+
# not need to be ordered.
15+
#
16+
# Unfortunately, it is not possible to do this generically, so we have to
17+
# monomorphize and define an aggregate for each data type that we want to
18+
# use. The `value` is always an `int8`
19+
#
20+
# If changes to these functions are needed, copy this script to a new
21+
# migration, change it and regenerate the up and down migrations
22+
23+
types="text bytea"
24+
dir=$(dirname $0)
25+
26+
read -d '' -r prelude <<'EOF'
27+
-- This file was generated by generate.sh in this directory
28+
set search_path = public;
29+
EOF
30+
31+
read -d '' -r up_template <<'EOF'
32+
create type public.@T@_and_value as (
33+
arg @T@,
34+
value int8
35+
);
36+
37+
create or replace function arg_min_agg_@T@ (a @T@_and_value, b @T@_and_value)
38+
returns @T@_and_value
39+
language sql immutable strict parallel safe as
40+
'select case when a.arg is null then b
41+
when b.arg is null then a
42+
when a.value <= b.value then a
43+
else b end';
44+
45+
create or replace function arg_max_agg_@T@ (a @T@_and_value, b @T@_and_value)
46+
returns @T@_and_value
47+
language sql immutable strict parallel safe as
48+
'select case when a.arg is null then b
49+
when b.arg is null then a
50+
when a.value > b.value then a
51+
else b end';
52+
53+
create function arg_from_@T@_and_value(a @T@_and_value)
54+
returns @T@
55+
language sql immutable strict parallel safe as
56+
'select a.arg';
57+
58+
create aggregate arg_min_@T@ (@T@_and_value) (
59+
sfunc = arg_min_agg_@T@,
60+
stype = @T@_and_value,
61+
finalfunc = arg_from_@T@_and_value,
62+
parallel = safe
63+
);
64+
65+
comment on aggregate arg_min_@T@(@T@_and_value) is
66+
'For ''select arg_min_@T@((arg, value)) from ..'' return the arg for the smallest value';
67+
68+
create aggregate arg_max_@T@ (@T@_and_value) (
69+
sfunc = arg_max_agg_@T@,
70+
stype = @T@_and_value,
71+
finalfunc = arg_from_@T@_and_value,
72+
parallel = safe
73+
);
74+
75+
comment on aggregate arg_max_@T@(@T@_and_value) is
76+
'For ''select arg_max_@T@((arg, value)) from ..'' return the arg for the largest value';
77+
EOF
78+
79+
read -d '' -r down_template <<'EOF'
80+
drop aggregate arg_min_@T@(@T@_and_value);
81+
drop aggregate arg_max_@T@(@T@_and_value);
82+
drop function arg_from_@T@_and_value(@T@_and_value);
83+
drop function arg_max_agg_@T@(@T@_and_value, @T@_and_value);
84+
drop function arg_min_agg_@T@(@T@_and_value, @T@_and_value);
85+
drop type @T@_and_value;
86+
EOF
87+
88+
echo "$prelude" > $dir/up.sql
89+
for typ in $types
90+
do
91+
echo "${up_template//@T@/$typ}" >> $dir/up.sql
92+
done
93+
94+
echo "$prelude" > $dir/down.sql
95+
for typ in $types
96+
do
97+
echo "${down_template//@T@/$typ}" >> $dir/down.sql
98+
done

0 commit comments

Comments
 (0)