Skip to content

Commit f41d79c

Browse files
authored
fix: secure sql queries generated by leaderboard stats (#611)
1 parent af64385 commit f41d79c

File tree

6 files changed

+404
-141
lines changed

6 files changed

+404
-141
lines changed
Lines changed: 60 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,31 @@
11
import _ from 'lodash';
22
import { Schemas } from 'forest-express';
33
import Orm from '../utils/orm';
4+
import { InvalidParameterError } from './errors';
45

6+
function getAggregateField({
7+
aggregateField, schemaRelationship, modelRelationship,
8+
}) {
9+
// NOTICE: As MySQL cannot support COUNT(table_name.*) syntax, fieldName cannot be '*'.
10+
const fieldName = aggregateField
11+
|| schemaRelationship.primaryKeys[0]
12+
|| schemaRelationship.fields[0].field;
13+
return `${modelRelationship.name}.${Orm.getColumnName(schemaRelationship, fieldName)}`;
14+
}
15+
16+
/**
17+
* @param {import('sequelize').Model} model
18+
* @param {import('sequelize').Model} modelRelationship
19+
* @param {{
20+
* label_field: string;
21+
* aggregate: string;
22+
* aggregate_field: string;
23+
* }} params
24+
* @param {*} options
25+
*/
526
function LeaderboardStatGetter(model, modelRelationship, params, options) {
627
const labelField = params.label_field;
728
const aggregate = params.aggregate.toUpperCase();
8-
const aggregateField = params.aggregate_field;
929
const { limit } = params;
1030
const schema = Schemas.schemas[model.name];
1131
const schemaRelationship = Schemas.schemas[modelRelationship.name];
@@ -15,50 +35,52 @@ function LeaderboardStatGetter(model, modelRelationship, params, options) {
1535
(association) => association.target.name === model.name,
1636
);
1737

18-
if (associationFound && associationFound.as) {
19-
associationAs = associationFound.as;
20-
}
21-
22-
const groupBy = `"${associationAs}"."${labelField}"`;
38+
const aggregateField = getAggregateField({
39+
aggregateField: params.aggregate_field,
40+
schemaRelationship,
41+
modelRelationship,
42+
});
2343

24-
function getAggregateField() {
25-
// NOTICE: As MySQL cannot support COUNT(table_name.*) syntax, fieldName cannot be '*'.
26-
const fieldName = aggregateField
27-
|| schemaRelationship.primaryKeys[0]
28-
|| schemaRelationship.fields[0].field;
29-
return `"${modelRelationship.tableName}"."${Orm.getColumnName(schema, fieldName)}"`;
44+
if (!associationFound) {
45+
throw new InvalidParameterError(`Association ${model.name} not found`);
3046
}
3147

32-
let joinQuery;
33-
if (associationFound.associationType === 'BelongsToMany') {
34-
const joinTableName = associationFound.through.model.tableName;
35-
joinQuery = `INNER JOIN "${joinTableName}"
36-
ON "${modelRelationship.tableName}"."${associationFound.sourceKeyField}" = "${joinTableName}"."${associationFound.foreignKey}"
37-
INNER JOIN "${model.tableName}" AS "${associationAs}"
38-
ON "${associationAs}"."${associationFound.targetKeyField}" = "${joinTableName}"."${associationFound.otherKey}"
39-
`;
40-
} else {
41-
const foreignKeyField = associationFound.source
42-
.rawAttributes[associationFound.foreignKey].field;
43-
joinQuery = `INNER JOIN "${model.tableName}" AS "${associationAs}"
44-
ON "${associationAs}"."${associationFound.targetKeyField}" = "${modelRelationship.tableName}"."${foreignKeyField}"
45-
`;
48+
if (associationFound.as) {
49+
associationAs = associationFound.as;
4650
}
4751

48-
const query = `
49-
SELECT ${aggregate}(${getAggregateField()}) as "value", ${groupBy} as "key"
50-
FROM "${modelRelationship.tableName}"
51-
${joinQuery}
52-
GROUP BY ${groupBy}
53-
ORDER BY "value" DESC
54-
LIMIT ${limit}
55-
`;
52+
const labelColumn = Orm.getColumnName(schema, labelField);
53+
const groupBy = `${associationAs}.${labelColumn}`;
5654

55+
this.perform = async () => {
56+
const records = await modelRelationship
57+
.unscoped()
58+
.findAll({
59+
attributes: [
60+
[options.sequelize.col(groupBy), 'key'],
61+
[options.sequelize.fn(aggregate, options.sequelize.col(aggregateField)), 'value'],
62+
],
63+
includeIgnoreAttributes: false,
64+
include: [{
65+
model,
66+
attributes: [labelField],
67+
as: associationAs,
68+
required: true,
69+
}],
70+
subQuery: false,
71+
group: groupBy,
72+
order: [[options.sequelize.literal('value'), 'DESC']],
73+
limit,
74+
raw: true,
75+
});
5776

58-
this.perform = () => options.connections[0].query(query, {
59-
type: model.sequelize.QueryTypes.SELECT,
60-
})
61-
.then((records) => ({ value: records }));
77+
return {
78+
value: records.map((data) => ({
79+
key: data.key,
80+
value: Number(data.value),
81+
})),
82+
};
83+
};
6284
}
6385

6486
module.exports = LeaderboardStatGetter;

test/databases.js

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
const Sequelize = require('sequelize');
22

3+
/** @typedef {ConnectionManager} ConnectionManager */
34
class ConnectionManager {
4-
constructor(connectionString) {
5+
constructor(dialect, connectionString) {
6+
this.dialect = dialect;
57
this.connectionString = connectionString;
68
this.databaseOptions = {
79
logging: false,
@@ -11,13 +13,16 @@ class ConnectionManager {
1113
}
1214

1315
getDialect() {
14-
return this.connection && this.connection.options && this.connection.options.dialect;
16+
return this.dialect;
1517
}
1618

1719
getPort() {
1820
return this.connection && this.connection.options && this.connection.options.port;
1921
}
2022

23+
/**
24+
* @returns {import('sequelize').Sequelize}
25+
*/
2126
createConnection() {
2227
if (!this.connection) {
2328
this.connection = new Sequelize(this.connectionString, this.databaseOptions);
@@ -33,8 +38,11 @@ class ConnectionManager {
3338
}
3439
}
3540

41+
/**
42+
* @type {Record<string, ConnectionManager>}
43+
*/
3644
module.exports = {
37-
sequelizePostgres: new ConnectionManager('postgres://forest:secret@localhost:5436/forest-express-sequelize-test'),
38-
sequelizeMySQLMin: new ConnectionManager('mysql://forest:secret@localhost:8998/forest-express-sequelize-test'),
39-
sequelizeMySQLMax: new ConnectionManager('mysql://forest:secret@localhost:8999/forest-express-sequelize-test'),
45+
sequelizePostgres: new ConnectionManager('Postgresql 9.4', 'postgres://forest:secret@localhost:5436/forest-express-sequelize-test'),
46+
sequelizeMySQLMin: new ConnectionManager('MySQL 5.6', 'mysql://forest:secret@localhost:8998/forest-express-sequelize-test'),
47+
sequelizeMySQLMax: new ConnectionManager('MySQL 8.0', 'mysql://forest:secret@localhost:8999/forest-express-sequelize-test'),
4048
};
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
[{
2+
"model": "theVendors",
3+
"data": {
4+
"id": 100,
5+
"firstName": "Alice",
6+
"lastName": "Doe"
7+
}
8+
},{
9+
"model": "theVendors",
10+
"data": {
11+
"id": 101,
12+
"firstName": "Bob",
13+
"lastName": "Doe"
14+
}
15+
},{
16+
"model": "theCustomers",
17+
"data": {
18+
"id": 100,
19+
"name": "big customer",
20+
"objectiveScore": 5
21+
}
22+
},{
23+
"model": "theCustomers",
24+
"data": {
25+
"id": 101,
26+
"name": "small customer",
27+
"objectiveScore": 1
28+
}
29+
},{
30+
"model": "theirSales",
31+
"data": {
32+
"id": 100,
33+
"vendorId": 100,
34+
"customerId": 100,
35+
"sellingAmount": 100
36+
}
37+
},{
38+
"model": "theirSales",
39+
"data": {
40+
"id": 101,
41+
"vendorId": 101,
42+
"customerId": 100,
43+
"sellingAmount": 200
44+
}
45+
},{
46+
"model": "theirSales",
47+
"data": {
48+
"id": 102,
49+
"vendorId": 100,
50+
"customerId": 101,
51+
"sellingAmount": 150
52+
}
53+
}]
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
/**
2+
* @param {import("../databases").ConnectionManager} connectionManager
3+
* @param {(connection: import("sequelize").Sequelize) => Promise<void>} testCallback
4+
*/
5+
async function runWithConnection(connectionManager, testCallback) {
6+
try {
7+
await testCallback(connectionManager.createConnection());
8+
} finally {
9+
connectionManager.closeConnection();
10+
}
11+
}
12+
13+
module.exports = runWithConnection;

0 commit comments

Comments
 (0)