Skip to content

Commit 7141bfe

Browse files
committed
Add field and table name escaping for python SqlServer connector
1 parent 082e28e commit 7141bfe

2 files changed

Lines changed: 85 additions & 35 deletions

File tree

python/semantic_kernel/connectors/sql_server.py

Lines changed: 35 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,15 @@ def append_list(self, strings: Sequence[str], sep: str = ", ", suffix: str | Non
154154
self.append(string, suffix=sep)
155155
self.append(strings[-1], suffix=suffix)
156156

157+
@staticmethod
158+
def escape_identifier(identifier: str) -> str:
159+
"""Escape and bracket-quote a SQL Server identifier.
160+
161+
Escapes `]` to `]]` and wraps in square brackets, matching
162+
the SQL Server standard for delimited identifiers.
163+
"""
164+
return f"[{identifier.replace(']', ']]')}]"
165+
157166
def append_table_name(
158167
self, schema: str, table_name: str, prefix: str = "", suffix: str | None = None, newline: bool = False
159168
) -> None:
@@ -169,7 +178,9 @@ def append_table_name(
169178
suffix: Optional suffix to add after the table name.
170179
newline: Whether to add a newline after the table name or suffix.
171180
"""
172-
self.append(f"{prefix} [{schema}].[{table_name}] {suffix or ''}", suffix="\n" if newline else "")
181+
escaped_schema = self.escape_identifier(schema)
182+
escaped_table = self.escape_identifier(table_name)
183+
self.append(f"{prefix} {escaped_schema}.{escaped_table} {suffix or ''}", suffix="\n" if newline else "")
173184

174185
def remove_last(self, number_of_chars: int):
175186
"""Remove the last number_of_chars from the StringBuilder."""
@@ -631,14 +642,14 @@ def parse(node: ast.AST) -> str:
631642
raise VectorStoreOperationException(
632643
f"Field '{node.attr}' not in data model (storage property names are used)."
633644
)
634-
return f"[{node.attr}]"
645+
return QueryBuilder.escape_identifier(node.attr)
635646
case ast.Name():
636647
# Only allow names that are in the data model
637648
if node.id not in self.definition.storage_names:
638649
raise VectorStoreOperationException(
639650
f"Field '{node.id}' not in data model (storage property names are used)."
640651
)
641-
return f"[{node.id}]"
652+
return QueryBuilder.escape_identifier(node.id)
642653
case ast.Constant():
643654
value = node.value
644655
if isinstance(value, (str, int, float, bool, bytes)) or value is None:
@@ -868,12 +879,15 @@ def _build_create_table_query(
868879
with command.query.in_parenthesis(suffix=";"):
869880
# add the key field
870881
command.query.append(
871-
f'"{key_field.storage_name or key_field.name}" '
882+
f"{QueryBuilder.escape_identifier(key_field.storage_name or key_field.name)} "
872883
f"{_python_type_to_sql(key_field.type_, is_key=True)} NOT NULL,\n"
873884
)
874885
# add the data fields
875886
[
876-
command.query.append(f'"{field.storage_name or field.name}" {_python_type_to_sql(field.type_)} NULL,\n')
887+
command.query.append(
888+
f"{QueryBuilder.escape_identifier(field.storage_name or field.name)}"
889+
f" {_python_type_to_sql(field.type_)} NULL,\n"
890+
)
877891
for field in data_fields
878892
]
879893
# add the vector fields
@@ -882,10 +896,11 @@ def _build_create_table_query(
882896
raise VectorStoreOperationException(
883897
f"Index kind '{field.index_kind}' is not supported for field '{field.name}'"
884898
)
885-
command.query.append(f'"{field.storage_name or field.name}" VECTOR({field.dimensions}) NULL,\n')
899+
escaped_name = QueryBuilder.escape_identifier(field.storage_name or field.name)
900+
command.query.append(f"{escaped_name} VECTOR({field.dimensions}) NULL,\n")
886901
# set the primary key
887902
with command.query.in_parenthesis("PRIMARY KEY", "\n"):
888-
command.query.append(key_field.name)
903+
command.query.append(QueryBuilder.escape_identifier(key_field.storage_name or key_field.name))
889904
return command
890905

891906

@@ -954,9 +969,11 @@ def _add_field_names(
954969
"""
955970
fields = chain([key_field], data_fields, vector_fields or [])
956971
if table_identifier:
957-
strings = [f"{table_identifier}.{field.storage_name or field.name}" for field in fields]
972+
strings = [
973+
f"{table_identifier}.{QueryBuilder.escape_identifier(field.storage_name or field.name)}" for field in fields
974+
]
958975
else:
959-
strings = [field.storage_name or field.name for field in fields]
976+
strings = [QueryBuilder.escape_identifier(field.storage_name or field.name) for field in fields]
960977
command.query.append_list(strings)
961978

962979

@@ -997,15 +1014,15 @@ def _build_merge_query(
9971014
_add_field_names(command, key_field, data_fields, vector_fields)
9981015
# add the ON clause
9991016
with command.query.in_parenthesis("ON", "\n"):
1000-
command.query.append(
1001-
f"t.{key_field.storage_name or key_field.name} = s.{key_field.storage_name or key_field.name}"
1002-
)
1017+
escaped_key = QueryBuilder.escape_identifier(key_field.storage_name or key_field.name)
1018+
command.query.append(f"t.{escaped_key} = s.{escaped_key}")
10031019
# Set the Matched clause
10041020
command.query.append("WHEN MATCHED THEN\n")
10051021
command.query.append("UPDATE SET ")
10061022
command.query.append_list(
10071023
[
1008-
f"t.{field.storage_name or field.name} = s.{field.storage_name or field.name}"
1024+
f"t.{QueryBuilder.escape_identifier(field.storage_name or field.name)}"
1025+
f" = s.{QueryBuilder.escape_identifier(field.storage_name or field.name)}"
10091026
for field in chain(data_fields, vector_fields)
10101027
],
10111028
suffix="\n",
@@ -1018,7 +1035,8 @@ def _build_merge_query(
10181035
with command.query.in_parenthesis("VALUES", " \n"):
10191036
_add_field_names(command, key_field, data_fields, vector_fields, table_identifier="s")
10201037
# return the upserted key through the OUTPUT clause
1021-
command.query.append(f"OUTPUT inserted.{key_field.name} INTO @UpsertedKeys (KeyColumn);\n")
1038+
escaped_key_out = QueryBuilder.escape_identifier(key_field.storage_name or key_field.name)
1039+
command.query.append(f"OUTPUT inserted.{escaped_key_out} INTO @UpsertedKeys (KeyColumn);\n")
10221040
command.query.append("SELECT KeyColumn FROM @UpsertedKeys;\n")
10231041
return command
10241042

@@ -1041,7 +1059,7 @@ def _build_select_query(
10411059
command.query.append_table_name(schema, table, prefix=" FROM", newline=True)
10421060
# add the WHERE clause
10431061
if keys:
1044-
command.query.append(f"WHERE {key_field.storage_name or key_field.name} IN\n")
1062+
command.query.append(f"WHERE {QueryBuilder.escape_identifier(key_field.storage_name or key_field.name)} IN\n")
10451063
with command.query.in_parenthesis():
10461064
# add the keys
10471065
command.query.append_list(["?"] * len(keys))
@@ -1061,7 +1079,7 @@ def _build_delete_query(
10611079
# start the DELETE statement
10621080
command.query.append_table_name(schema, table)
10631081
# add the WHERE clause
1064-
command.query.append(f"WHERE [{key_field.storage_name or key_field.name}] IN")
1082+
command.query.append(f"WHERE {QueryBuilder.escape_identifier(key_field.storage_name or key_field.name)} IN")
10651083
with command.query.in_parenthesis():
10661084
# add the keys
10671085
command.query.append_list(["?"] * len(keys))
@@ -1109,7 +1127,7 @@ def _build_search_query(
11091127
asc = DISTANCE_FUNCTION_DIRECTION_HELPER[vector_field.distance_function](0, 1)
11101128

11111129
command.query.append(
1112-
f", VECTOR_DISTANCE('{distance_function}', {vector_field.storage_name or vector_field.name}, CAST(? AS VECTOR({vector_field.dimensions}))) as {SCORE_FIELD_NAME}\n", # noqa: E501
1130+
f", VECTOR_DISTANCE('{distance_function}', {QueryBuilder.escape_identifier(vector_field.storage_name or vector_field.name)}, CAST(? AS VECTOR({vector_field.dimensions}))) as {SCORE_FIELD_NAME}\n", # noqa: E501
11131131
)
11141132
command.add_parameter(_cast_value(vector))
11151133
# add the FROM clause

python/tests/unit/connectors/memory/test_sql_server.py

Lines changed: 50 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,32 @@ def test_query_builder_append_list(self):
4242
result = str(qb).strip()
4343
assert result == "id, name, age;"
4444

45+
def test_query_builder_escape_identifier(self):
    """escape_identifier wraps the name in brackets and doubles every ``]``."""
    expectations = {
        "simple": "[simple]",
        "has]bracket": "[has]]bracket]",
        "two]]brackets": "[two]]]]brackets]",
        "": "[]",
    }
    for raw, quoted in expectations.items():
        assert QueryBuilder.escape_identifier(raw) == quoted
50+
4551
def test_query_builder_append_table_name(self):
    """A plain schema/table pair is rendered as ``[schema].[table]`` between prefix and suffix."""
    builder = QueryBuilder()
    builder.append_table_name("dbo", "Users", prefix="SELECT * FROM", suffix=";", newline=False)
    assert str(builder).strip() == "SELECT * FROM [dbo].[Users] ;"
5056

57+
def test_query_builder_append_table_name_escapes_closing_bracket(self):
    """A ``]`` inside the schema or table name is doubled in the rendered identifier."""
    builder = QueryBuilder()
    builder.append_table_name("my]schema", "my]table", prefix="SELECT * FROM", suffix=";")
    rendered = str(builder).strip()
    assert rendered == "SELECT * FROM [my]]schema].[my]]table] ;"
62+
63+
def test_query_builder_append_table_name_prevents_sql_injection(self):
    """A table name that tries to close the bracket early must stay inside the quoted identifier."""
    qb = QueryBuilder()
    qb.append("DROP TABLE IF EXISTS")
    qb.append_table_name("dbo", "]; EXEC xp_cmdshell('whoami'); --", suffix=";")
    result = str(qb)
    # An unescaped name would render as `[dbo].[]; EXEC ...`, closing the identifier
    # right before the payload. Inspecting only the text ahead of the first "].["
    # cannot detect that, because that prefix never contains the payload.
    assert "[dbo].[]; EXEC" not in result, "SQL injection should not escape bracket quoting"
    # The payload's leading `]` is doubled, so the whole string stays one identifier.
    assert "[dbo].[]]; EXEC xp_cmdshell('whoami'); --]" in result
70+
5171
def test_query_builder_remove_last(self):
5272
qb = QueryBuilder("SELECT * FROM table;")
5373
qb.remove_last(1) # remove trailing semicolon
@@ -121,8 +141,8 @@ def test_build_create_table_query(self):
121141
cmd_str = str(cmd.query)
122142
assert (
123143
cmd_str
124-
== 'BEGIN\nCREATE TABLE [dbo].[Test] \n ("id" nvarchar(255) NOT NULL,\n"name" nvarchar(max) NULL,\n"age" '
125-
'int NULL,\n"embedding" VECTOR(1536) NULL,\nPRIMARY KEY (id) \n) ;\nEND\n'
144+
== "BEGIN\nCREATE TABLE [dbo].[Test] \n ([id] nvarchar(255) NOT NULL,\n[name] nvarchar(max) NULL,\n[age] "
145+
"int NULL,\n[embedding] VECTOR(1536) NULL,\nPRIMARY KEY ([id]) \n) ;\nEND\n"
126146
)
127147

128148
def test_delete_table_query(self):
@@ -170,11 +190,17 @@ def test_build_merge_query(self):
170190
assert cmd.parameters[3] == json.dumps(records[0]["embedding"])
171191
str_cmd = str(cmd)
172192
assert str_cmd == (
173-
"DECLARE @UpsertedKeys TABLE (KeyColumn nvarchar(255));\nMERGE INTO [dbo].[Test] AS t\nUSING ( "
174-
"VALUES (?, ?, ?, ?) ) AS s (id, name, age, embedding) ON (t.id = s.id) \nWHEN MATCHED THEN\nUPDATE "
175-
"SET t.name = s.name, t.age = s.age, t.embedding = s.embedding\nWHEN NOT MATCHED THEN\nINSERT "
176-
"(id, name, age, embedding) VALUES (s.id, s.name, s.age, s.embedding) \nOUTPUT inserted.id "
177-
"INTO @UpsertedKeys (KeyColumn);\nSELECT KeyColumn FROM @UpsertedKeys;\n"
193+
"DECLARE @UpsertedKeys TABLE (KeyColumn nvarchar(255));\n"
194+
"MERGE INTO [dbo].[Test] AS t\n"
195+
"USING ( VALUES (?, ?, ?, ?) ) AS s ([id], [name], [age], [embedding]) "
196+
"ON (t.[id] = s.[id]) \n"
197+
"WHEN MATCHED THEN\n"
198+
"UPDATE SET t.[name] = s.[name], t.[age] = s.[age], t.[embedding] = s.[embedding]\n"
199+
"WHEN NOT MATCHED THEN\n"
200+
"INSERT ([id], [name], [age], [embedding]) "
201+
"VALUES (s.[id], s.[name], s.[age], s.[embedding]) \n"
202+
"OUTPUT inserted.[id] INTO @UpsertedKeys (KeyColumn);\n"
203+
"SELECT KeyColumn FROM @UpsertedKeys;\n"
178204
)
179205

180206
def test_build_select_query(self):
@@ -192,7 +218,7 @@ def test_build_select_query(self):
192218
cmd = _build_select_query(schema, table, key_field, data_fields, vector_fields, keys)
193219
assert cmd.parameters == ["test"]
194220
str_cmd = str(cmd)
195-
assert str_cmd == "SELECT\nid, name, age, embedding FROM [dbo].[Test] \nWHERE id IN\n (?) ;"
221+
assert str_cmd == "SELECT\n[id], [name], [age], [embedding] FROM [dbo].[Test] \nWHERE [id] IN\n (?) ;"
196222

197223
def test_build_delete_query(self):
198224
schema = "dbo"
@@ -230,7 +256,7 @@ def test_build_search_query(self):
230256
assert cmd.parameters[0] == json.dumps(vector)
231257
str_cmd = str(cmd)
232258
assert (
233-
str_cmd == "SELECT id, name, age, VECTOR_DISTANCE('cosine', embedding, CAST(? AS VECTOR(5))) as "
259+
str_cmd == "SELECT [id], [name], [age], VECTOR_DISTANCE('cosine', [embedding], CAST(? AS VECTOR(5))) as "
234260
"_vector_distance_value\n FROM [dbo].[Test] \nORDER BY "
235261
"_vector_distance_value ASC\nOFFSET 0 ROWS FETCH NEXT 3 ROWS ONLY;"
236262
)
@@ -378,11 +404,17 @@ async def test_upsert(
378404
await collection.upsert(record)
379405
mock_connection.cursor.return_value.__enter__.return_value.execute.assert_called_with(
380406
(
381-
"DECLARE @UpsertedKeys TABLE (KeyColumn nvarchar(255));\nMERGE INTO [dbo].[test] AS t\nUSING ( VALUES"
382-
" (?, ?, ?) ) AS s (id, content, vector) ON (t.id = s.id) \nWHEN MATCHED THEN\nUPDATE SET t.content"
383-
" = s.content, t.vector = s.vector\nWHEN NOT MATCHED THEN\nINSERT (id, content, vector) VALUES (s.id, "
384-
"s.content, s.vector) \nOUTPUT inserted.id INTO @UpsertedKeys (KeyColumn);\nSELECT KeyColumn "
385-
"FROM @UpsertedKeys;\n"
407+
"DECLARE @UpsertedKeys TABLE (KeyColumn nvarchar(255));\n"
408+
"MERGE INTO [dbo].[test] AS t\n"
409+
"USING ( VALUES (?, ?, ?) ) AS s ([id], [content], [vector]) "
410+
"ON (t.[id] = s.[id]) \n"
411+
"WHEN MATCHED THEN\n"
412+
"UPDATE SET t.[content] = s.[content], t.[vector] = s.[vector]\n"
413+
"WHEN NOT MATCHED THEN\n"
414+
"INSERT ([id], [content], [vector]) "
415+
"VALUES (s.[id], s.[content], s.[vector]) \n"
416+
"OUTPUT inserted.[id] INTO @UpsertedKeys (KeyColumn);\n"
417+
"SELECT KeyColumn FROM @UpsertedKeys;\n"
386418
),
387419
("1", "test", json.dumps([0.1, 0.2, 0.3, 0.4, 0.5])),
388420
)
@@ -415,7 +447,7 @@ class MockRow(NamedTuple):
415447
mock_cursor.__iter__.return_value = [row]
416448
record = await collection.get(key, include_vectors=True)
417449
mock_cursor.execute.assert_called_with(
418-
"SELECT\nid, content, vector FROM [dbo].[test] \nWHERE id IN\n (?) ;", ("1",)
450+
"SELECT\n[id], [content], [vector] FROM [dbo].[test] \nWHERE [id] IN\n (?) ;", ("1",)
419451
)
420452
assert record["id"] == "1"
421453
assert record["content"] == "test"
@@ -478,7 +510,7 @@ class MockRow:
478510
assert record.score == 0.1
479511
mock_cursor.execute.assert_called_with(
480512
(
481-
"SELECT id, content, VECTOR_DISTANCE('cosine', vector, CAST(? AS VECTOR(5))) as "
513+
"SELECT [id], [content], VECTOR_DISTANCE('cosine', [vector], CAST(? AS VECTOR(5))) as "
482514
"_vector_distance_value\n FROM [dbo].[test] \n WHERE [content] = ? \nORDER BY _vector_distance_value "
483515
"ASC\nOFFSET 0 ROWS FETCH NEXT 3 ROWS ONLY;"
484516
),
@@ -502,8 +534,8 @@ async def test_ensure_collection_exists(
502534
await collection.ensure_collection_exists()
503535
mock_connection.cursor.return_value.__enter__.return_value.execute.assert_called_with(
504536
(
505-
"IF OBJECT_ID(N' [dbo].[test] ', N'U') IS NULL\nBEGIN\nCREATE TABLE [dbo].[test] \n (\"id\" nvarchar"
506-
'(255) NOT NULL,\n"content" nvarchar(max) NULL,\n"vector" VECTOR(5) NULL,\nPRIMARY KEY (id) \n) ;'
537+
"IF OBJECT_ID(N' [dbo].[test] ', N'U') IS NULL\nBEGIN\nCREATE TABLE [dbo].[test] \n ([id] nvarchar"
538+
"(255) NOT NULL,\n[content] nvarchar(max) NULL,\n[vector] VECTOR(5) NULL,\nPRIMARY KEY ([id]) \n) ;"
507539
"\nEND\n"
508540
),
509541
(),

0 commit comments

Comments
 (0)