Skip to content

Commit 5524089

Browse files
authored
Add JWT array payload support for authorization (#1434)
* Add JWT array payload support for authorization Signed-off-by: Marvin Froeder <[email protected]> * Add FIXME comment to ArrayContainsSqlTranslation for unwrapCast removal Signed-off-by: Marvin Froeder <[email protected]> * Update functionParameterTest.txt snapshot Signed-off-by: Marvin Froeder <[email protected]> --------- Signed-off-by: Marvin Froeder <[email protected]>
1 parent 5f99d08 commit 5524089

File tree

11 files changed

+310
-3
lines changed

11 files changed

+310
-3
lines changed

sqrl-planner/src/main/java/com/datasqrl/function/PgSpecificOperatorTable.java

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,10 @@
1818
import static com.datasqrl.function.CalciteFunctionUtil.lightweightOp;
1919

2020
import org.apache.calcite.sql.SqlBinaryOperator;
21+
import org.apache.calcite.sql.SqlCall;
2122
import org.apache.calcite.sql.SqlKind;
2223
import org.apache.calcite.sql.SqlUnresolvedFunction;
24+
import org.apache.calcite.sql.SqlWriter;
2325
import org.apache.calcite.sql.type.ReturnTypes;
2426
import org.apache.calcite.sql.type.SqlTypeName;
2527

@@ -67,4 +69,27 @@ public class PgSpecificOperatorTable {
6769
ReturnTypes.explicit(SqlTypeName.ANY),
6870
null,
6971
null);
72+
73+
public static final SqlBinaryOperator EqualsAny =
74+
new SqlBinaryOperator(
75+
"= ANY",
76+
SqlKind.OTHER_FUNCTION,
77+
22,
78+
true,
79+
ReturnTypes.explicit(SqlTypeName.BOOLEAN),
80+
null,
81+
null) {
82+
83+
@Override
84+
public void unparse(SqlWriter writer, SqlCall call, int leftPrec, int rightPrec) {
85+
var left = call.operand(0);
86+
var right = call.operand(1);
87+
88+
left.unparse(writer, leftPrec, getLeftPrec());
89+
writer.keyword("= ANY");
90+
writer.print("(");
91+
right.unparse(writer, getRightPrec(), rightPrec);
92+
writer.print(")");
93+
}
94+
};
7095
}
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/*
2+
* Copyright © 2021 DataSQRL ([email protected])
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package com.datasqrl.function.builtinflink;
17+
18+
import com.datasqrl.function.CalciteFunctionUtil;
19+
import com.datasqrl.function.PgSpecificOperatorTable;
20+
import com.datasqrl.function.translations.PostgresSqlTranslation;
21+
import com.datasqrl.function.translations.SqlTranslation;
22+
import com.google.auto.service.AutoService;
23+
import org.apache.calcite.sql.SqlBasicCall;
24+
import org.apache.calcite.sql.SqlCall;
25+
import org.apache.calcite.sql.SqlNode;
26+
import org.apache.calcite.sql.SqlWriter;
27+
import org.apache.calcite.sql.parser.SqlParserPos;
28+
import org.apache.flink.table.functions.BuiltInFunctionDefinitions;
29+
30+
@AutoService(SqlTranslation.class)
31+
public class ArrayContainsSqlTranslation extends PostgresSqlTranslation {
32+
33+
public ArrayContainsSqlTranslation() {
34+
super(CalciteFunctionUtil.lightweightOp(BuiltInFunctionDefinitions.ARRAY_CONTAINS));
35+
}
36+
37+
@Override
38+
public void unparse(SqlCall call, SqlWriter writer, int leftPrec, int rightPrec) {
39+
var rawArray = call.getOperandList().get(0);
40+
var value = call.getOperandList().get(1);
41+
42+
var array = unwrapCast(rawArray);
43+
44+
// Emit: value = ANY(array)
45+
PgSpecificOperatorTable.EqualsAny.createCall(SqlParserPos.ZERO, value, array)
46+
.unparse(writer, leftPrec, rightPrec);
47+
}
48+
49+
// FIXME: Remove unwrapCast method
50+
private SqlNode unwrapCast(SqlNode node) {
51+
if (node instanceof SqlBasicCall) {
52+
SqlBasicCall call = (SqlBasicCall) node;
53+
if (call.getOperator().getName().equalsIgnoreCase("CAST")) {
54+
return call.getOperandList().get(0);
55+
}
56+
}
57+
return node;
58+
}
59+
}

sqrl-server/sqrl-server-vertx-base/src/main/java/com/datasqrl/graphql/auth/AuthMetadataReader.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import com.datasqrl.graphql.server.MetadataReader;
2121
import graphql.schema.DataFetchingEnvironment;
22+
import io.vertx.core.json.JsonArray;
2223
import io.vertx.ext.web.RoutingContext;
2324
import lombok.extern.slf4j.Slf4j;
2425

@@ -49,6 +50,11 @@ public Object read(DataFetchingEnvironment env, String name, boolean isRequired)
4950
checkNotNull(value, "Claim '%s' must not be null", name);
5051
}
5152

53+
if (value instanceof JsonArray array) {
54+
// Unwrap JsonArray to plain Java array to avoid pgclient treating it as JSONB
55+
return array.getList().toArray(new Object[0]);
56+
}
57+
5258
return value;
5359
}
5460
}

sqrl-testing/sqrl-integration-tests/src/test/java/com/datasqrl/FullUseCasesIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ void printUseCaseNumbers(UseCaseTestParameter param) {
339339
public void runTestCaseByName() {
340340
var param =
341341
getSpecificUseCase(
342-
p -> p.sqrlFileName.startsWith("flink_kafka.sqrl") && p.goal.equals("test"));
342+
p -> p.sqrlFileName.startsWith("jwt-authorized.sqrl") && p.goal.equals("test"));
343343

344344
useCase(param);
345345
}

sqrl-testing/sqrl-integration-tests/src/test/resources/snapshots/com/datasqrl/DAGPlannerTest/functionParameterTest.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,7 @@ INSERT INTO `default_catalog`.`default_database`.`CustomerByNothing_2`
351351
],
352352
"query" : {
353353
"type" : "SqlQuery",
354-
"sql" : "SELECT *\nFROM \"Customer_1\"\nWHERE \"array_contains\"(CAST($1 AS JSONB), \"customerid\")",
354+
"sql" : "SELECT *\nFROM \"Customer_1\"\nWHERE (\"customerid\" = ANY ($1))",
355355
"parameters" : [
356356
{
357357
"type" : "arg",

sqrl-testing/sqrl-integration-tests/src/test/resources/snapshots/com/datasqrl/UseCaseCompileTest/jwt-authorized--package.txt

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,25 @@ SQL: CREATE VIEW AuthMyTable AS
1818
WHERE t.val = ?
1919
ORDER BY t.val ASC;
2020

21+
=== AuthMyTableValues
22+
ID: default_catalog.default_database.AuthMyTableValues
23+
Type: query
24+
Stage: postgres
25+
Inputs: default_catalog.default_database.MyTable
26+
Annotations:
27+
- parameters: vals
28+
- base-table: MyTable
29+
Plan:
30+
LogicalSort(sort0=[$0], dir0=[ASC-nulls-first])
31+
LogicalProject(val=[$0])
32+
LogicalFilter(condition=[array_contains(CAST(?0):BIGINT ARRAY, CAST($0):BIGINT)])
33+
LogicalTableScan(table=[[default_catalog, default_database, MyTable]])
34+
SQL: CREATE VIEW AuthMyTableValues AS
35+
SELECT t.*
36+
FROM MyTable t
37+
WHERE array_contains(cast(? as ARRAY<BIGINT>), t.val)
38+
ORDER BY t.val ASC;
39+
2140
=== MyTable
2241
ID: default_catalog.default_database.MyTable
2342
Type: state
@@ -139,6 +158,39 @@ CREATE TABLE IF NOT EXISTS "MyTable" ("val" INTEGER NOT NULL , PRIMARY KEY ("val
139158
"database" : "POSTGRES"
140159
}
141160
}
161+
},
162+
{
163+
"type" : "args",
164+
"parentType" : "Query",
165+
"fieldName" : "AuthMyTableValues",
166+
"exec" : {
167+
"arguments" : [
168+
{
169+
"type" : "variable",
170+
"path" : "offset"
171+
},
172+
{
173+
"type" : "variable",
174+
"path" : "limit"
175+
}
176+
],
177+
"query" : {
178+
"type" : "SqlQuery",
179+
"sql" : "SELECT *\nFROM (SELECT \"val\"\n FROM \"MyTable\"\n ORDER BY \"val\" NULLS FIRST) AS \"t\"\nWHERE (CAST(\"val\" AS BIGINT) = ANY ($1))\nORDER BY \"val\" NULLS FIRST",
180+
"parameters" : [
181+
{
182+
"type" : "metadata",
183+
"metadata" : {
184+
"metadataType" : "AUTH",
185+
"name" : "values",
186+
"isRequired" : true
187+
}
188+
}
189+
],
190+
"pagination" : "LIMIT_AND_OFFSET",
191+
"database" : "POSTGRES"
192+
}
193+
}
142194
}
143195
],
144196
"mutations" : [ ],
@@ -195,11 +247,37 @@ CREATE TABLE IF NOT EXISTS "MyTable" ("val" INTEGER NOT NULL , PRIMARY KEY ("val
195247
"mcpMethod" : "TOOL",
196248
"restMethod" : "GET",
197249
"uriTemplate" : "queries/AuthMyTable{?offset,limit}"
250+
},
251+
{
252+
"function" : {
253+
"name" : "GetAuthMyTableValues",
254+
"parameters" : {
255+
"type" : "object",
256+
"properties" : {
257+
"offset" : {
258+
"type" : "integer"
259+
},
260+
"limit" : {
261+
"type" : "integer"
262+
}
263+
},
264+
"required" : [ ]
265+
}
266+
},
267+
"format" : "JSON",
268+
"apiQuery" : {
269+
"query" : "query AuthMyTableValues($limit: Int = 10, $offset: Int = 0) {\nAuthMyTableValues(limit: $limit, offset: $offset) {\nval\n}\n\n}",
270+
"queryName" : "AuthMyTableValues",
271+
"operationType" : "QUERY"
272+
},
273+
"mcpMethod" : "TOOL",
274+
"restMethod" : "GET",
275+
"uriTemplate" : "queries/AuthMyTableValues{?offset,limit}"
198276
}
199277
],
200278
"schema" : {
201279
"type" : "string",
202-
"schema" : "\"An RFC-3339 compliant Full Date Scalar\"\nscalar Date\n\n\"A slightly refined version of RFC-3339 compliant DateTime Scalar\"\nscalar DateTime\n\n\"A JSON scalar\"\nscalar JSON\n\n\"24-hour clock time value string in the format `hh:mm:ss` or `hh:mm:ss.sss`.\"\nscalar LocalTime\n\n\"A 64-bit signed integer\"\nscalar Long\n\ntype MyTable {\n val: Int!\n}\n\ntype Query {\n MyTable(limit: Int = 10, offset: Int = 0): [MyTable!]\n AuthMyTable(limit: Int = 10, offset: Int = 0): [MyTable!]\n}\n\nenum _McpMethodType {\n NONE\n TOOL\n RESOURCE\n}\n\nenum _RestMethodType {\n NONE\n GET\n POST\n}\n\ndirective @api(mcp: _McpMethodType, rest: _RestMethodType, uri: String) on QUERY | MUTATION | FIELD_DEFINITION\n"
280+
"schema" : "\"An RFC-3339 compliant Full Date Scalar\"\nscalar Date\n\n\"A slightly refined version of RFC-3339 compliant DateTime Scalar\"\nscalar DateTime\n\n\"A JSON scalar\"\nscalar JSON\n\n\"24-hour clock time value string in the format `hh:mm:ss` or `hh:mm:ss.sss`.\"\nscalar LocalTime\n\n\"A 64-bit signed integer\"\nscalar Long\n\ntype MyTable {\n val: Int!\n}\n\ntype Query {\n MyTable(limit: Int = 10, offset: Int = 0): [MyTable!]\n AuthMyTable(limit: Int = 10, offset: Int = 0): [MyTable!]\n AuthMyTableValues(limit: Int = 10, offset: Int = 0): [MyTable!]\n}\n\nenum _McpMethodType {\n NONE\n TOOL\n RESOURCE\n}\n\nenum _RestMethodType {\n NONE\n GET\n POST\n}\n\ndirective @api(mcp: _McpMethodType, rest: _RestMethodType, uri: String) on QUERY | MUTATION | FIELD_DEFINITION\n"
203281
}
204282
}
205283
}
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
#!/usr/bin/env python3
2+
"""
3+
JWT Token Generation Utility for DataSQRL JWT Authorization Tests
4+
5+
This utility generates JWT tokens for testing JWT-based authorization in DataSQRL.
6+
The tokens are compatible with the jwt-authorized test suite configuration.
7+
8+
REQUIREMENTS:
9+
pip install PyJWT
10+
11+
USAGE:
12+
1. Edit the PAYLOAD variable below to set the desired JWT claims
13+
2. Run: python3 generate_jwt_tokens.py
14+
3. Copy the generated token to your test .properties file
15+
16+
JWT CONFIGURATION:
17+
The tokens are generated using the secret key and parameters from package.json:
18+
- Secret: testSecretThatIsAtLeast256BitsLong32Chars (base64: dGVzdFNlY3JldFRoYXRJc0F0TGVhc3QyNTZCaXRzTG9uZzMyQ2hhcnM=)
19+
- Algorithm: HS256
20+
- Issuer: my-test-issuer
21+
- Audience: ["my-test-audience"]
22+
- Default Expiration: 9999999999 (far future)
23+
24+
SQRL FUNCTIONS:
25+
- AuthMyTable(val BIGINT): Expects JWT claim "val" with single integer value
26+
- AuthMyTableValues(val ARRAY<BIGINT>): Expects JWT claim "values" with array of integers
27+
28+
EXAMPLES:
29+
# For AuthMyTable with val=73, edit PAYLOAD to:
30+
# PAYLOAD = {
31+
# "iss": "my-test-issuer",
32+
# "aud": ["my-test-audience"],
33+
# "exp": 9999999999,
34+
# "val": 73
35+
# }
36+
37+
# For AuthMyTableValues with values=[1,2,82], edit PAYLOAD to:
38+
# PAYLOAD = {
39+
# "iss": "my-test-issuer",
40+
# "aud": ["my-test-audience"],
41+
# "exp": 9999999999,
42+
# "values": [1, 2, 82]
43+
# }
44+
"""
45+
46+
import jwt
47+
import json
48+
49+
# JWT configuration matching package.json
50+
SECRET = "testSecretThatIsAtLeast256BitsLong32Chars"
51+
ALGORITHM = "HS256"
52+
53+
# ==============================================================================
54+
# EDIT THIS PAYLOAD TO GENERATE DIFFERENT TOKENS
55+
# ==============================================================================
56+
# Change the payload below and run this script to generate a new JWT token
57+
PAYLOAD = {
58+
"iss": "my-test-issuer",
59+
"aud": ["my-test-audience"],
60+
"exp": 9999999999,
61+
"values": [1, 2, 82] # Example: for AuthMyTableValues with values 1, 2, 82
62+
# "val": 73 # Example: for AuthMyTable with single value 73
63+
# "val": None # Example: for null value
64+
# (omit val/values) # Example: for no-val scenarios
65+
}
66+
# ==============================================================================
67+
68+
69+
def generate_token():
70+
"""Generate JWT token using the hardcoded PAYLOAD."""
71+
token = jwt.encode(PAYLOAD, SECRET, algorithm=ALGORITHM)
72+
return token
73+
74+
75+
def decode_token(token):
76+
"""Decode and verify JWT token."""
77+
try:
78+
decoded = jwt.decode(token, SECRET, algorithms=[ALGORITHM],
79+
audience=PAYLOAD["aud"], issuer=PAYLOAD["iss"])
80+
return decoded
81+
except jwt.InvalidTokenError as e:
82+
return f"Invalid token: {e}"
83+
84+
85+
def main():
86+
"""Generate token using the hardcoded PAYLOAD and display results."""
87+
print("=" * 80)
88+
print("JWT Token Generator for DataSQRL JWT Authorization Tests")
89+
print("=" * 80)
90+
91+
print("\nCurrent Payload:")
92+
print(json.dumps(PAYLOAD, indent=2))
93+
94+
token = generate_token()
95+
96+
print(f"\nGenerated JWT Token:")
97+
print(token)
98+
99+
print(f"\nUse in .properties file:")
100+
print(f"Authorization: Bearer {token}")
101+
102+
# Verify by decoding
103+
print(f"\nVerification (decoded payload):")
104+
decoded = decode_token(token)
105+
if isinstance(decoded, dict):
106+
print(json.dumps(decoded, indent=2))
107+
else:
108+
print(f"Error: {decoded}")
109+
110+
print("\n" + "=" * 80)
111+
print("To generate different tokens:")
112+
print("1. Edit the PAYLOAD variable in this script")
113+
print("2. Run this script again")
114+
print("=" * 80)
115+
116+
117+
if __name__ == "__main__":
118+
main()

sqrl-testing/sqrl-integration-tests/src/test/resources/usecases/jwt-authorized/jwt-authorized.sqrl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,9 @@ AuthMyTable(val BIGINT NOT NULL METADATA FROM 'auth.val') :=
1010
FROM MyTable t
1111
WHERE t.val = :val
1212
ORDER BY t.val ASC;
13+
14+
AuthMyTableValues(vals ARRAY<BIGINT> NOT NULL METADATA FROM 'auth.values') :=
15+
SELECT t.*
16+
FROM MyTable t
17+
WHERE array_contains(cast(:vals as ARRAY<BIGINT>), t.val)
18+
ORDER BY t.val ASC;
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
{
2+
"data" : {
3+
"AuthMyTableValues" : [ {
4+
"val" : 1
5+
}, {
6+
"val" : 2
7+
} ]
8+
}
9+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
query {
2+
AuthMyTableValues {
3+
val
4+
}
5+
}

0 commit comments

Comments
 (0)