Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 88 additions & 34 deletions graph/src/schema/input/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -788,7 +788,7 @@ pub struct Aggregate {
}

impl Aggregate {
fn new(_schema: &Schema, name: &str, field_type: &s::Type, dir: &s::Directive) -> Self {
fn new(schema: &Schema, name: &str, field_type: &s::Type, dir: &s::Directive) -> Self {
let func = dir
.argument("fn")
.unwrap()
Expand Down Expand Up @@ -818,7 +818,7 @@ impl Aggregate {
arg,
cumulative,
field_type: field_type.clone(),
value_type: field_type.get_base_type().parse().unwrap(),
value_type: Field::scalar_value_type(schema, field_type),
}
}

Expand Down Expand Up @@ -2366,27 +2366,63 @@ mod validations {
}
}

fn aggregate_fields_are_numbers(agg_type: &s::ObjectType, errors: &mut Vec<Err>) {
fn aggregate_field_types(
schema: &Schema,
agg_type: &s::ObjectType,
errors: &mut Vec<Err>,
) {
fn is_first_last(agg_directive: &s::Directive) -> bool {
match agg_directive.argument(kw::FUNC) {
Some(s::Value::Enum(func) | s::Value::String(func)) => {
func == AggregateFn::First.as_str()
|| func == AggregateFn::Last.as_str()
}
_ => false,
}
}

let errs = agg_type
.fields
.iter()
.filter(|field| field.find_directive(kw::AGGREGATE).is_some())
.map(|field| match field.field_type.value_type() {
Ok(vt) => {
if vt.is_numeric() {
Ok(())
} else {
Err(Err::NonNumericAggregate(
.filter_map(|field| {
field
.find_directive(kw::AGGREGATE)
.map(|agg_directive| (field, agg_directive))
})
.map(|(field, agg_directive)| {
let is_first_last = is_first_last(agg_directive);

match field.field_type.value_type() {
Ok(value_type) if value_type.is_numeric() => Ok(()),
Ok(ValueType::Bytes | ValueType::String) if is_first_last => Ok(()),
Ok(_) if is_first_last => Err(Err::InvalidFirstLastAggregate(
agg_type.name.clone(),
field.name.clone(),
)),
Ok(_) => Err(Err::NonNumericAggregate(
agg_type.name.to_owned(),
field.name.to_owned(),
)),
Err(_) => {
if is_first_last
&& schema
.entity_types
.iter()
.find(|entity_type| {
entity_type.name.eq(field.field_type.get_base_type())
})
.is_some()
{
return Ok(());
}

Err(Err::FieldTypeUnknown(
agg_type.name.to_owned(),
field.name.to_owned(),
field.field_type.get_base_type().to_owned(),
))
}
}
Err(_) => Err(Err::FieldTypeUnknown(
agg_type.name.to_owned(),
field.name.to_owned(),
field.field_type.get_base_type().to_owned(),
)),
})
.filter_map(|err| err.err());
errors.extend(errs);
Expand Down Expand Up @@ -2519,16 +2555,10 @@ mod validations {
continue;
}
};
let field_type = match field.field_type.value_type() {
Ok(field_type) => field_type,
Err(_) => {
errors.push(Err::NonNumericAggregate(
agg_type.name.to_owned(),
field.name.to_owned(),
));
continue;
}
};

let is_first_last =
matches!(func, AggregateFn::First | AggregateFn::Last);

// It would be nicer to use a proper struct here
// and have that implement
// `sqlexpr::ExprVisitor` but we need access to
Expand All @@ -2539,6 +2569,18 @@ mod validations {
let arg_type = match source.field(ident) {
Some(arg_field) => match arg_field.field_type.value_type() {
Ok(arg_type) if arg_type.is_numeric() => arg_type,
Ok(ValueType::Bytes | ValueType::String)
if is_first_last =>
{
return Ok(());
}
Err(_)
if is_first_last
&& arg_field.field_type.get_base_type()
== field.field_type.get_base_type() =>
{
return Ok(());
}
Ok(_) | Err(_) => {
return Err(Err::AggregationNonNumericArg(
agg_type.name.to_owned(),
Expand All @@ -2556,15 +2598,27 @@ mod validations {
));
}
};
if arg_type > field_type {
return Err(Err::AggregationNonMatchingArg(
agg_type.name.to_owned(),
field.name.to_owned(),
arg.to_owned(),
arg_type.to_str().to_owned(),
field_type.to_str().to_owned(),
));

match field.field_type.value_type() {
Ok(field_type) if field_type.is_numeric() => {
if arg_type > field_type {
return Err(Err::AggregationNonMatchingArg(
agg_type.name.to_owned(),
field.name.to_owned(),
arg.to_owned(),
arg_type.to_str().to_owned(),
field_type.to_str().to_owned(),
));
}
}
Ok(_) | Err(_) => {
return Err(Err::NonNumericAggregate(
agg_type.name.to_owned(),
field.name.to_owned(),
));
}
}

Ok(())
};
if let Err(mut errs) = sqlexpr::parse(arg, check_ident) {
Expand Down Expand Up @@ -2661,7 +2715,7 @@ mod validations {
errors.push(err);
}
no_derived_fields(agg_type, &mut errors);
aggregate_fields_are_numbers(agg_type, &mut errors);
aggregate_field_types(self, agg_type, &mut errors);
aggregate_directive(self, agg_type, &mut errors);
// check timeseries directive has intervals and args
aggregation_intervals(agg_type, &mut errors);
Expand Down
2 changes: 2 additions & 0 deletions graph/src/schema/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ pub enum SchemaValidationError {
TimestampFieldMissing(String),
/// A field carrying an `@aggregate` directive has a non-numeric type.
///
/// The payload is `(aggregation type name, field name)`.
#[error("Aggregation {0}, field {1}: aggregates must use a numeric type, one of Int, Int8, BigInt, and BigDecimal")]
NonNumericAggregate(String, String),
#[error("Aggregation '{0}', field '{1}': first/last aggregates must use a numeric, byte array, string or a reference type")]
InvalidFirstLastAggregate(String, String),
#[error("Aggregation {0} is missing the `source` argument")]
AggregationMissingSource(String),
#[error(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# fail: FieldTypeUnknown("Stats", "firstBlockNumber", "BlockNumber")

type Data @entity(timeseries: true) {
id: Int8!
timestamp: Timestamp!
blockNumber: Int8!
}

type Stats @aggregation(intervals: ["hour", "day"], source: "Data") {
id: Int8!
timestamp: Timestamp!

firstBlockNumber: BlockNumber! @aggregate(fn: "first", arg: "blockNumber")
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# fail: FieldTypeUnknown("Stats", "lastBlockNumber", "BlockNumber")

type Data @entity(timeseries: true) {
id: Int8!
timestamp: Timestamp!
blockNumber: Int8!
}

type Stats @aggregation(intervals: ["hour", "day"], source: "Data") {
id: Int8!
timestamp: Timestamp!

lastBlockNumber: BlockNumber! @aggregate(fn: "last", arg: "blockNumber")
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# valid: Non-numeric first and last aggregations

type EntityA @entity(immutable: true) {
id: ID!
}

type EntityB @entity(immutable: true) {
id: Int8!
}

type EntityC @entity(immutable: true) {
id: Bytes!
}

type Data @entity(timeseries: true) {
id: Int8!
timestamp: Timestamp!
fieldA: EntityA!
fieldB: EntityB!
fieldC: EntityC!
fieldD: String!
fieldE: Bytes!
}

type Stats @aggregation(intervals: ["hour", "day"], source: "Data") {
id: Int8!
timestamp: Timestamp!

firstA: EntityA! @aggregate(fn: "first", arg: "fieldA")
lastA: EntityA! @aggregate(fn: "last", arg: "fieldA")

firstB: EntityB! @aggregate(fn: "first", arg: "fieldB")
lastB: EntityB! @aggregate(fn: "last", arg: "fieldB")

firstC: EntityC! @aggregate(fn: "first", arg: "fieldC")
lastC: EntityC! @aggregate(fn: "last", arg: "fieldC")

firstD: String! @aggregate(fn: "first", arg: "fieldD")
lastD: String! @aggregate(fn: "last", arg: "fieldD")

firstE: Bytes! @aggregate(fn: "first", arg: "fieldE")
lastE: Bytes! @aggregate(fn: "last", arg: "fieldE")
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
-- This file was generated by generate.sh in this directory
-- All objects below are looked up in the `public` schema.
set search_path = public;

-- Tear down the text variant: the aggregates go first, then the functions,
-- and finally the composite type that appears in the function signatures.
drop aggregate arg_min_text(text_and_value);
drop aggregate arg_max_text(text_and_value);
drop function arg_from_text_and_value(text_and_value);
drop function arg_max_agg_text(text_and_value, text_and_value);
drop function arg_min_agg_text(text_and_value, text_and_value);
drop type text_and_value;

-- Tear down the bytea variant in the same order.
drop aggregate arg_min_bytea(bytea_and_value);
drop aggregate arg_max_bytea(bytea_and_value);
drop function arg_from_bytea_and_value(bytea_and_value);
drop function arg_max_agg_bytea(bytea_and_value, bytea_and_value);
drop function arg_min_agg_bytea(bytea_and_value, bytea_and_value);
drop type bytea_and_value;
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
-- This file was generated by generate.sh in this directory
-- Create everything in the `public` schema.
set search_path = public;

-- Composite type pairing a text `arg` with the int8 `value` it is ranked by.
-- It is both the input type and the state type of the arg_min_text and
-- arg_max_text aggregates defined further down.
create type public.text_and_value as (
arg text,
value int8
);

-- State transition function of the arg_min_text aggregate.
-- Given two (arg, value) pairs it returns the one with the smaller `value`:
--   * a pair whose `arg` is null yields to the other pair,
--   * on equal values `a` is kept (the comparison is `<=`),
--   * otherwise, including when the comparison evaluates to null, `b` wins.
create or replace function arg_min_agg_text (a text_and_value, b text_and_value)
returns text_and_value
language sql immutable strict parallel safe as
'select case when a.arg is null then b
when b.arg is null then a
when a.value <= b.value then a
else b end';

-- State transition function of the arg_max_text aggregate.
-- Given two (arg, value) pairs it returns the one with the larger `value`:
--   * a pair whose `arg` is null yields to the other pair,
--   * `a` is kept only when its value is strictly greater (`>`),
--   * otherwise, including on equal values or a null comparison, `b` wins.
create or replace function arg_max_agg_text (a text_and_value, b text_and_value)
returns text_and_value
language sql immutable strict parallel safe as
'select case when a.arg is null then b
when b.arg is null then a
when a.value > b.value then a
else b end';

-- Final function of the arg_min_text / arg_max_text aggregates: unwraps the
-- winning (arg, value) pair and returns only its text `arg`.
-- Unlike the transition functions this uses plain `create function`
-- (no `or replace`), so running it when the function exists is an error.
create function arg_from_text_and_value(a text_and_value)
returns text
language sql immutable strict parallel safe as
'select a.arg';

-- Aggregate over (arg, value) pairs that returns the text `arg` belonging to
-- the smallest `value`. State is a text_and_value pair folded with
-- arg_min_agg_text; arg_from_text_and_value extracts the result at the end.
create aggregate arg_min_text (text_and_value) (
sfunc = arg_min_agg_text,
stype = text_and_value,
finalfunc = arg_from_text_and_value,
parallel = safe
);

-- Catalog description attached to the aggregate.
comment on aggregate arg_min_text(text_and_value) is
'For ''select arg_min_text((arg, value)) from ..'' return the arg for the smallest value';

-- Aggregate over (arg, value) pairs that returns the text `arg` belonging to
-- the largest `value`. State is a text_and_value pair folded with
-- arg_max_agg_text; arg_from_text_and_value extracts the result at the end.
create aggregate arg_max_text (text_and_value) (
sfunc = arg_max_agg_text,
stype = text_and_value,
finalfunc = arg_from_text_and_value,
parallel = safe
);

-- Catalog description attached to the aggregate.
comment on aggregate arg_max_text(text_and_value) is
'For ''select arg_max_text((arg, value)) from ..'' return the arg for the largest value';
-- Composite type pairing a bytea `arg` with the int8 `value` it is ranked by.
-- It is both the input type and the state type of the arg_min_bytea and
-- arg_max_bytea aggregates defined further down.
create type public.bytea_and_value as (
arg bytea,
value int8
);

-- State transition function of the arg_min_bytea aggregate.
-- Given two (arg, value) pairs it returns the one with the smaller `value`:
--   * a pair whose `arg` is null yields to the other pair,
--   * on equal values `a` is kept (the comparison is `<=`),
--   * otherwise, including when the comparison evaluates to null, `b` wins.
create or replace function arg_min_agg_bytea (a bytea_and_value, b bytea_and_value)
returns bytea_and_value
language sql immutable strict parallel safe as
'select case when a.arg is null then b
when b.arg is null then a
when a.value <= b.value then a
else b end';

-- State transition function of the arg_max_bytea aggregate.
-- Given two (arg, value) pairs it returns the one with the larger `value`:
--   * a pair whose `arg` is null yields to the other pair,
--   * `a` is kept only when its value is strictly greater (`>`),
--   * otherwise, including on equal values or a null comparison, `b` wins.
create or replace function arg_max_agg_bytea (a bytea_and_value, b bytea_and_value)
returns bytea_and_value
language sql immutable strict parallel safe as
'select case when a.arg is null then b
when b.arg is null then a
when a.value > b.value then a
else b end';

-- Final function of the arg_min_bytea / arg_max_bytea aggregates: unwraps the
-- winning (arg, value) pair and returns only its bytea `arg`.
-- Unlike the transition functions this uses plain `create function`
-- (no `or replace`), so running it when the function exists is an error.
create function arg_from_bytea_and_value(a bytea_and_value)
returns bytea
language sql immutable strict parallel safe as
'select a.arg';

-- Aggregate over (arg, value) pairs that returns the bytea `arg` belonging to
-- the smallest `value`. State is a bytea_and_value pair folded with
-- arg_min_agg_bytea; arg_from_bytea_and_value extracts the result at the end.
create aggregate arg_min_bytea (bytea_and_value) (
sfunc = arg_min_agg_bytea,
stype = bytea_and_value,
finalfunc = arg_from_bytea_and_value,
parallel = safe
);

-- Catalog description attached to the aggregate.
comment on aggregate arg_min_bytea(bytea_and_value) is
'For ''select arg_min_bytea((arg, value)) from ..'' return the arg for the smallest value';

-- Aggregate over (arg, value) pairs that returns the bytea `arg` belonging to
-- the largest `value`. State is a bytea_and_value pair folded with
-- arg_max_agg_bytea; arg_from_bytea_and_value extracts the result at the end.
create aggregate arg_max_bytea (bytea_and_value) (
sfunc = arg_max_agg_bytea,
stype = bytea_and_value,
finalfunc = arg_from_bytea_and_value,
parallel = safe
);

-- Catalog description attached to the aggregate.
comment on aggregate arg_max_bytea(bytea_and_value) is
'For ''select arg_max_bytea((arg, value)) from ..'' return the arg for the largest value';