Skip to content

Commit 0797bb4

Browse files
committed
Enhance DML detection and command tag consistency
- Centralize SQL statement categorization in IRISSQLParser.
- Add support for MERGE and TRUNCATE command tags.
- Use the centralized parser in the protocol handler to ensure consistent batch flushing.
- Refactor the embedded execution path to use the shared command tag logic.
1 parent e0bbef1 commit 0797bb4

File tree

4 files changed

+54
-48
lines changed

4 files changed

+54
-48
lines changed

src/iris_pgwire/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
caretdev/sqlalchemy-iris.
77
"""
88

9-
__version__ = "1.2.5"
9+
__version__ = "1.2.6"
1010
__author__ = "Thomas Dyar <thomas.dyar@intersystems.com>"
1111

1212
# Don't import server/protocol in __init__ to avoid sys.modules conflicts

src/iris_pgwire/iris_executor.py

Lines changed: 22 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
SQLTranslator, # Feature 021: PostgreSQL→IRIS normalization
3131
TransactionTranslator,
3232
) # Feature 022: PostgreSQL transaction verb translation
33+
from .sql_translator.parser import get_parser
3334
from .sql_translator.alias_extractor import AliasExtractor # Column alias preservation
3435
from .sql_translator.performance_monitor import MetricType, PerformanceTracker, get_monitor
3536
from .type_mapping import (
@@ -104,6 +105,7 @@ def __init__(self, iris_config: dict[str, Any], server=None):
104105
self.sql_pipeline = SQLPipeline()
105106
self.sql_interceptor = SQLInterceptor(self)
106107
self.sql_translator = self.sql_pipeline.translator
108+
self.sql_parser = get_parser()
107109
self.transaction_translator = TransactionTranslator()
108110

109111
# Connection pool management
@@ -121,9 +123,9 @@ def __init__(self, iris_config: dict[str, Any], server=None):
121123

122124
logger.info(
123125
"IRIS executor initialized",
124-
host=iris_config.get("host"),
125-
port=iris_config.get("port"),
126-
namespace=iris_config.get("namespace"),
126+
host=self.iris_config.get("host"),
127+
port=self.iris_config.get("port"),
128+
namespace=self.iris_config.get("namespace"),
127129
embedded_mode=self.embedded_mode,
128130
)
129131

@@ -3241,15 +3243,7 @@ def _sync_execute():
32413243
session_id=session_id,
32423244
)
32433245
# Determine command tag from SQL
3244-
sql_upper = sql.strip().upper()
3245-
if sql_upper.startswith("DELETE"):
3246-
command_tag = "DELETE"
3247-
elif sql_upper.startswith("UPDATE"):
3248-
command_tag = "UPDATE"
3249-
elif sql_upper.startswith("INSERT"):
3250-
command_tag = "INSERT"
3251-
else:
3252-
command_tag = "UNKNOWN"
3246+
command_tag = self._determine_command_tag(sql, 0)
32533247

32543248
return {
32553249
"success": True, # SQLCODE 100 is success!
@@ -4380,30 +4374,27 @@ def _map_iris_type_to_oid(self, iris_type: str) -> int:
43804374

43814375
def _determine_command_tag(self, sql: str, row_count: int) -> str:
43824376
"""Determine PostgreSQL command tag from SQL"""
4383-
sql_upper = sql.upper().strip()
4377+
# Normalize: strip and get first word
4378+
sql_clean = sql.strip().upper()
4379+
if not sql_clean:
4380+
return "UNKNOWN"
43844381

4385-
if sql_upper.startswith("SELECT"):
4382+
first_word = sql_clean.split()[0] if sql_clean.split() else ""
4383+
4384+
if first_word == "SELECT":
43864385
return "SELECT"
4387-
elif sql_upper.startswith("INSERT"):
4386+
elif first_word == "INSERT":
43884387
return f"INSERT 0 {row_count}"
4389-
elif sql_upper.startswith("UPDATE"):
4388+
elif first_word == "UPDATE":
43904389
return f"UPDATE {row_count}"
4391-
elif sql_upper.startswith("DELETE"):
4390+
elif first_word == "DELETE":
43924391
return f"DELETE {row_count}"
4393-
elif sql_upper.startswith("CREATE"):
4394-
return "CREATE"
4395-
elif sql_upper.startswith("DROP"):
4396-
return "DROP"
4397-
elif sql_upper.startswith("ALTER"):
4398-
return "ALTER"
4399-
elif sql_upper.startswith("BEGIN"):
4400-
return "BEGIN"
4401-
elif sql_upper.startswith("COMMIT"):
4402-
return "COMMIT"
4403-
elif sql_upper.startswith("ROLLBACK"):
4404-
return "ROLLBACK"
4405-
elif sql_upper.startswith("SHOW"):
4406-
return "SHOW"
4392+
elif first_word == "MERGE":
4393+
return f"MERGE {row_count}"
4394+
elif first_word == "TRUNCATE":
4395+
return "TRUNCATE"
4396+
elif first_word in ("CREATE", "DROP", "ALTER", "BEGIN", "COMMIT", "ROLLBACK", "SHOW"):
4397+
return first_word
44074398
else:
44084399
return "UNKNOWN"
44094400

src/iris_pgwire/protocol.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2906,24 +2906,26 @@ async def handle_describe_message(self, body: bytes):
29062906
connection_id=self.connection_id,
29072907
statement_name=name,
29082908
query_preview=query[:100],
2909-
is_select=query_upper.startswith("SELECT"),
2910-
is_show=query_upper.startswith("SHOW"),
2909+
is_select=self.iris_executor.sql_parser.is_select_statement(query),
2910+
is_show=self.iris_executor.sql_parser.is_show_statement(query),
29112911
)
29122912

29132913
# Check if query has RETURNING clause (INSERT/UPDATE/DELETE with RETURNING)
2914-
has_returning = "RETURNING" in query_upper
2914+
has_returning = self.iris_executor.sql_parser.has_returning_clause(query)
29152915

29162916
if (
2917-
query_upper.startswith("SELECT")
2918-
or query_upper.startswith("SHOW")
2917+
self.iris_executor.sql_parser.is_select_statement(query)
2918+
or self.iris_executor.sql_parser.is_show_statement(query)
29192919
or has_returning
29202920
):
29212921
# Execute metadata discovery to get column information
29222922
# Use LIMIT 0 pattern to avoid fetching actual data
29232923
# For RETURNING queries, we'll send synthetic column metadata based on RETURNING columns
29242924
try:
29252925
# Special handling for RETURNING queries - extract columns from RETURNING clause
2926-
if has_returning and not query_upper.startswith("SELECT"):
2926+
if has_returning and not self.iris_executor.sql_parser.is_select_statement(
2927+
query
2928+
):
29272929
import re
29282930

29292931
# Extract RETURNING columns from query
@@ -3085,14 +3087,19 @@ async def handle_describe_message(self, body: bytes):
30853087

30863088
# Batch Execution Interception:
30873089
# Short-circuit Describe portal for DML statements.
3088-
is_dml = any(query_upper.startswith(k) for k in ["INSERT", "UPDATE", "DELETE"])
3089-
has_returning = "RETURNING" in query_upper
3090-
3091-
if is_dml and not has_returning:
3090+
# DML portals (without RETURNING) don't have RowDescription.
3091+
is_dml = self.iris_executor.sql_parser.is_dml_statement(query)
3092+
if is_dml:
3093+
logger.info(
3094+
"Describe portal: DML statement (no row metadata)", portal_name=name
3095+
)
30923096
await self.send_no_data()
30933097
return
30943098

3095-
if query_upper.startswith("SELECT") or query_upper.startswith("SHOW"):
3099+
# Metadata discovery for SELECT/SHOW
3100+
if self.iris_executor.sql_parser.is_select_statement(
3101+
query
3102+
) or self.iris_executor.sql_parser.is_show_statement(query):
30963103
try:
30973104
result = await self.iris_executor.execute_query(
30983105
query, params=portal.get("params", [])
@@ -3277,9 +3284,8 @@ async def handle_execute_message(self, body: bytes):
32773284
# Intercept INSERT/UPDATE/DELETE (DML) without RETURNING for protocol-level batching.
32783285
# Standard PostgreSQL clients (psycopg3) send Sync every 5 rows, which is slow.
32793286
# We buffer parameters and send synthetic CommandComplete to keep client pipe full.
3280-
query_upper = query.strip().upper()
3281-
is_dml = any(query_upper.startswith(k) for k in ["INSERT", "UPDATE", "DELETE"])
3282-
has_returning = "RETURNING" in query_upper
3287+
is_dml = self.iris_executor.sql_parser.is_dml_statement(query)
3288+
has_returning = self.iris_executor.sql_parser.has_returning_clause(query)
32833289

32843290
if is_dml and not has_returning:
32853291
# Store SQL if first row in batch
@@ -3291,6 +3297,7 @@ async def handle_execute_message(self, body: bytes):
32913297

32923298
# Send synthetic CommandComplete immediately to client
32933299
# This tricks the client into sending the next row immediately.
3300+
query_upper = query.strip().upper()
32943301
tag = f"{query_upper.split()[0]} 0 1\x00".encode()
32953302
msg_len = 4 + len(tag)
32963303
self.writer.write(struct.pack("!cI", MSG_COMMAND_COMPLETE, msg_len) + tag)

src/iris_pgwire/sql_translator/parser.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -593,10 +593,18 @@ def is_select_statement(self, sql: str) -> bool:
593593
sql_clean = sql.strip().upper()
594594
return sql_clean.startswith("SELECT") or sql_clean.startswith("WITH")
595595

596+
def is_show_statement(self, sql: str) -> bool:
597+
"""Check if SQL is a SHOW statement"""
598+
return sql.strip().upper().startswith("SHOW")
599+
596600
def is_dml_statement(self, sql: str) -> bool:
597-
"""Check if SQL is a DML statement (INSERT, UPDATE, DELETE)"""
601+
"""Check if SQL is a DML statement (INSERT, UPDATE, DELETE, MERGE)"""
598602
sql_clean = sql.strip().upper()
599-
return any(sql_clean.startswith(stmt) for stmt in ["INSERT", "UPDATE", "DELETE"])
603+
return any(sql_clean.startswith(stmt) for stmt in ["INSERT", "UPDATE", "DELETE", "MERGE"])
604+
605+
def has_returning_clause(self, sql: str) -> bool:
606+
"""Check if SQL has a RETURNING clause"""
607+
return "RETURNING" in sql.upper()
600608

601609
def is_ddl_statement(self, sql: str) -> bool:
602610
"""Check if SQL is a DDL statement"""

0 commit comments

Comments
 (0)