@@ -42,12 +42,32 @@ def test_query_builder_append_list(self):
4242 result = str (qb ).strip ()
4343 assert result == "id, name, age;"
4444
45+ def test_query_builder_escape_identifier (self ):
46+ assert QueryBuilder .escape_identifier ("simple" ) == "[simple]"
47+ assert QueryBuilder .escape_identifier ("has]bracket" ) == "[has]]bracket]"
48+ assert QueryBuilder .escape_identifier ("two]]brackets" ) == "[two]]]]brackets]"
49+ assert QueryBuilder .escape_identifier ("" ) == "[]"
50+
4551 def test_query_builder_append_table_name (self ):
4652 qb = QueryBuilder ()
4753 qb .append_table_name ("dbo" , "Users" , prefix = "SELECT * FROM" , suffix = ";" , newline = False )
4854 result = str (qb ).strip ()
4955 assert result == "SELECT * FROM [dbo].[Users] ;"
5056
57+ def test_query_builder_append_table_name_escapes_closing_bracket (self ):
58+ qb = QueryBuilder ()
59+ qb .append_table_name ("my]schema" , "my]table" , prefix = "SELECT * FROM" , suffix = ";" )
60+ result = str (qb ).strip ()
61+ assert result == "SELECT * FROM [my]]schema].[my]]table] ;"
62+
63+ def test_query_builder_append_table_name_prevents_sql_injection (self ):
64+ qb = QueryBuilder ()
65+ qb .append ("DROP TABLE IF EXISTS" )
66+ qb .append_table_name ("dbo" , "]; EXEC xp_cmdshell('whoami'); --" , suffix = ";" )
67+ result = str (qb )
68+ assert "EXEC xp_cmdshell" not in result .split ("].[" )[0 ], "SQL injection should not escape bracket quoting"
69+ assert "[dbo].[]]; EXEC xp_cmdshell('whoami'); --]" in result
70+
5171 def test_query_builder_remove_last (self ):
5272 qb = QueryBuilder ("SELECT * FROM table;" )
5373 qb .remove_last (1 ) # remove trailing semicolon
@@ -121,8 +141,8 @@ def test_build_create_table_query(self):
121141 cmd_str = str (cmd .query )
122142 assert (
123143 cmd_str
124- == ' BEGIN\n CREATE TABLE [dbo].[Test] \n ("id" nvarchar(255) NOT NULL,\n " name" nvarchar(max) NULL,\n " age" '
125- ' int NULL,\n " embedding" VECTOR(1536) NULL,\n PRIMARY KEY (id ) \n ) ;\n END\n '
144+ == " BEGIN\n CREATE TABLE [dbo].[Test] \n ([id] nvarchar(255) NOT NULL,\n [ name] nvarchar(max) NULL,\n [ age] "
145+ " int NULL,\n [ embedding] VECTOR(1536) NULL,\n PRIMARY KEY ([id] ) \n ) ;\n END\n "
126146 )
127147
128148 def test_delete_table_query (self ):
@@ -170,11 +190,17 @@ def test_build_merge_query(self):
170190 assert cmd .parameters [3 ] == json .dumps (records [0 ]["embedding" ])
171191 str_cmd = str (cmd )
172192 assert str_cmd == (
173- "DECLARE @UpsertedKeys TABLE (KeyColumn nvarchar(255));\n MERGE INTO [dbo].[Test] AS t\n USING ( "
174- "VALUES (?, ?, ?, ?) ) AS s (id, name, age, embedding) ON (t.id = s.id) \n WHEN MATCHED THEN\n UPDATE "
175- "SET t.name = s.name, t.age = s.age, t.embedding = s.embedding\n WHEN NOT MATCHED THEN\n INSERT "
176- "(id, name, age, embedding) VALUES (s.id, s.name, s.age, s.embedding) \n OUTPUT inserted.id "
177- "INTO @UpsertedKeys (KeyColumn);\n SELECT KeyColumn FROM @UpsertedKeys;\n "
193+ "DECLARE @UpsertedKeys TABLE (KeyColumn nvarchar(255));\n "
194+ "MERGE INTO [dbo].[Test] AS t\n "
195+ "USING ( VALUES (?, ?, ?, ?) ) AS s ([id], [name], [age], [embedding]) "
196+ "ON (t.[id] = s.[id]) \n "
197+ "WHEN MATCHED THEN\n "
198+ "UPDATE SET t.[name] = s.[name], t.[age] = s.[age], t.[embedding] = s.[embedding]\n "
199+ "WHEN NOT MATCHED THEN\n "
200+ "INSERT ([id], [name], [age], [embedding]) "
201+ "VALUES (s.[id], s.[name], s.[age], s.[embedding]) \n "
202+ "OUTPUT inserted.[id] INTO @UpsertedKeys (KeyColumn);\n "
203+ "SELECT KeyColumn FROM @UpsertedKeys;\n "
178204 )
179205
180206 def test_build_select_query (self ):
@@ -192,7 +218,7 @@ def test_build_select_query(self):
192218 cmd = _build_select_query (schema , table , key_field , data_fields , vector_fields , keys )
193219 assert cmd .parameters == ["test" ]
194220 str_cmd = str (cmd )
195- assert str_cmd == "SELECT\n id, name, age, embedding FROM [dbo].[Test] \n WHERE id IN\n (?) ;"
221+ assert str_cmd == "SELECT\n [id], [ name], [ age], [ embedding] FROM [dbo].[Test] \n WHERE [id] IN\n (?) ;"
196222
197223 def test_build_delete_query (self ):
198224 schema = "dbo"
@@ -230,7 +256,7 @@ def test_build_search_query(self):
230256 assert cmd .parameters [0 ] == json .dumps (vector )
231257 str_cmd = str (cmd )
232258 assert (
233- str_cmd == "SELECT id, name, age, VECTOR_DISTANCE('cosine', embedding, CAST(? AS VECTOR(5))) as "
259+ str_cmd == "SELECT [id], [ name], [ age] , VECTOR_DISTANCE('cosine', [ embedding] , CAST(? AS VECTOR(5))) as "
234260 "_vector_distance_value\n FROM [dbo].[Test] \n ORDER BY "
235261 "_vector_distance_value ASC\n OFFSET 0 ROWS FETCH NEXT 3 ROWS ONLY;"
236262 )
@@ -378,11 +404,17 @@ async def test_upsert(
378404 await collection .upsert (record )
379405 mock_connection .cursor .return_value .__enter__ .return_value .execute .assert_called_with (
380406 (
381- "DECLARE @UpsertedKeys TABLE (KeyColumn nvarchar(255));\n MERGE INTO [dbo].[test] AS t\n USING ( VALUES"
382- " (?, ?, ?) ) AS s (id, content, vector) ON (t.id = s.id) \n WHEN MATCHED THEN\n UPDATE SET t.content"
383- " = s.content, t.vector = s.vector\n WHEN NOT MATCHED THEN\n INSERT (id, content, vector) VALUES (s.id, "
384- "s.content, s.vector) \n OUTPUT inserted.id INTO @UpsertedKeys (KeyColumn);\n SELECT KeyColumn "
385- "FROM @UpsertedKeys;\n "
407+ "DECLARE @UpsertedKeys TABLE (KeyColumn nvarchar(255));\n "
408+ "MERGE INTO [dbo].[test] AS t\n "
409+ "USING ( VALUES (?, ?, ?) ) AS s ([id], [content], [vector]) "
410+ "ON (t.[id] = s.[id]) \n "
411+ "WHEN MATCHED THEN\n "
412+ "UPDATE SET t.[content] = s.[content], t.[vector] = s.[vector]\n "
413+ "WHEN NOT MATCHED THEN\n "
414+ "INSERT ([id], [content], [vector]) "
415+ "VALUES (s.[id], s.[content], s.[vector]) \n "
416+ "OUTPUT inserted.[id] INTO @UpsertedKeys (KeyColumn);\n "
417+ "SELECT KeyColumn FROM @UpsertedKeys;\n "
386418 ),
387419 ("1" , "test" , json .dumps ([0.1 , 0.2 , 0.3 , 0.4 , 0.5 ])),
388420 )
@@ -415,7 +447,7 @@ class MockRow(NamedTuple):
415447 mock_cursor .__iter__ .return_value = [row ]
416448 record = await collection .get (key , include_vectors = True )
417449 mock_cursor .execute .assert_called_with (
418- "SELECT\n id, content, vector FROM [dbo].[test] \n WHERE id IN\n (?) ;" , ("1" ,)
450+ "SELECT\n [id], [ content], [ vector] FROM [dbo].[test] \n WHERE [id] IN\n (?) ;" , ("1" ,)
419451 )
420452 assert record ["id" ] == "1"
421453 assert record ["content" ] == "test"
@@ -478,7 +510,7 @@ class MockRow:
478510 assert record .score == 0.1
479511 mock_cursor .execute .assert_called_with (
480512 (
481- "SELECT id, content, VECTOR_DISTANCE('cosine', vector, CAST(? AS VECTOR(5))) as "
513+ "SELECT [id], [ content] , VECTOR_DISTANCE('cosine', [ vector] , CAST(? AS VECTOR(5))) as "
482514 "_vector_distance_value\n FROM [dbo].[test] \n WHERE [content] = ? \n ORDER BY _vector_distance_value "
483515 "ASC\n OFFSET 0 ROWS FETCH NEXT 3 ROWS ONLY;"
484516 ),
@@ -502,8 +534,8 @@ async def test_ensure_collection_exists(
502534 await collection .ensure_collection_exists ()
503535 mock_connection .cursor .return_value .__enter__ .return_value .execute .assert_called_with (
504536 (
505- "IF OBJECT_ID(N' [dbo].[test] ', N'U') IS NULL\n BEGIN\n CREATE TABLE [dbo].[test] \n (\" id \" nvarchar"
506- ' (255) NOT NULL,\n " content" nvarchar(max) NULL,\n " vector" VECTOR(5) NULL,\n PRIMARY KEY (id ) \n ) ;'
537+ "IF OBJECT_ID(N' [dbo].[test] ', N'U') IS NULL\n BEGIN\n CREATE TABLE [dbo].[test] \n ([id] nvarchar"
538+ " (255) NOT NULL,\n [ content] nvarchar(max) NULL,\n [ vector] VECTOR(5) NULL,\n PRIMARY KEY ([id] ) \n ) ;"
507539 "\n END\n "
508540 ),
509541 (),
0 commit comments