Skip to content

Commit 0c3aec1

Browse files
authored
feat: support statement tags in hints (#1579)
1 parent 10cc93b commit 0c3aec1

File tree

2 files changed

+160
-21
lines changed

2 files changed

+160
-21
lines changed
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.cloud.spanner.jdbc;
18+
19+
import static org.junit.Assert.assertEquals;
20+
import static org.junit.Assert.assertFalse;
21+
import static org.junit.Assert.assertTrue;
22+
23+
import com.google.cloud.spanner.Dialect;
24+
import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult;
25+
import com.google.cloud.spanner.connection.AbstractMockServerTest;
26+
import com.google.spanner.v1.ExecuteSqlRequest;
27+
import com.google.spanner.v1.RequestOptions.Priority;
28+
import java.sql.Connection;
29+
import java.sql.DriverManager;
30+
import java.sql.ResultSet;
31+
import java.sql.SQLException;
32+
import org.junit.After;
33+
import org.junit.Before;
34+
import org.junit.Test;
35+
import org.junit.runner.RunWith;
36+
import org.junit.runners.Parameterized;
37+
import org.junit.runners.Parameterized.Parameter;
38+
import org.junit.runners.Parameterized.Parameters;
39+
40+
@RunWith(Parameterized.class)
41+
public class ClientSideStatementHintsTest extends AbstractMockServerTest {
42+
43+
@Parameter public Dialect dialect;
44+
45+
private Dialect currentDialect;
46+
47+
@Parameters(name = "dialect = {0}")
48+
public static Object[] data() {
49+
return Dialect.values();
50+
}
51+
52+
@Before
53+
public void setupDialect() {
54+
if (this.dialect != currentDialect) {
55+
mockSpanner.putStatementResult(StatementResult.detectDialectResult(this.dialect));
56+
this.currentDialect = dialect;
57+
}
58+
}
59+
60+
@After
61+
public void clearRequests() {
62+
mockSpanner.clearRequests();
63+
}
64+
65+
private String createUrl() {
66+
return String.format(
67+
"jdbc:cloudspanner://localhost:%d/projects/%s/instances/%s/databases/%s?usePlainText=true",
68+
getPort(), "proj", "inst", "db" + (dialect == Dialect.POSTGRESQL ? "pg" : ""));
69+
}
70+
71+
private Connection createConnection() throws SQLException {
72+
return DriverManager.getConnection(createUrl());
73+
}
74+
75+
@Test
76+
public void testStatementTagInHint() throws SQLException {
77+
try (Connection connection = createConnection()) {
78+
try (ResultSet resultSet =
79+
connection
80+
.createStatement()
81+
.executeQuery(
82+
dialect == Dialect.POSTGRESQL
83+
? "/*@statement_tag='test-tag'*/SELECT 1"
84+
: "@{statement_tag='test-tag'}SELECT 1")) {
85+
assertTrue(resultSet.next());
86+
assertEquals(1L, resultSet.getLong(1));
87+
assertFalse(resultSet.next());
88+
}
89+
}
90+
assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class));
91+
ExecuteSqlRequest request = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(0);
92+
assertEquals("test-tag", request.getRequestOptions().getRequestTag());
93+
}
94+
95+
@Test
96+
public void testRpcPriorityInHint() throws SQLException {
97+
try (Connection connection = createConnection()) {
98+
try (ResultSet resultSet =
99+
connection
100+
.createStatement()
101+
.executeQuery(
102+
dialect == Dialect.POSTGRESQL
103+
? "/*@rpc_priority=PRIORITY_LOW*/SELECT 1"
104+
: "@{rpc_priority=PRIORITY_LOW}SELECT 1")) {
105+
assertTrue(resultSet.next());
106+
assertEquals(1L, resultSet.getLong(1));
107+
assertFalse(resultSet.next());
108+
}
109+
}
110+
assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class));
111+
ExecuteSqlRequest request = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(0);
112+
assertEquals(Priority.PRIORITY_LOW, request.getRequestOptions().getPriority());
113+
}
114+
}

src/test/java/com/google/cloud/spanner/jdbc/TagMockServerTest.java

Lines changed: 46 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import static org.junit.Assert.assertArrayEquals;
2020
import static org.junit.Assert.assertEquals;
21+
import static org.junit.Assert.assertFalse;
2122
import static org.junit.Assert.assertTrue;
2223

2324
import com.google.cloud.spanner.Dialect;
@@ -40,16 +41,25 @@
4041

4142
@RunWith(Parameterized.class)
4243
public class TagMockServerTest extends AbstractMockServerTest {
44+
private static final String SELECT_RANDOM_SQL = SELECT_RANDOM_STATEMENT.getSql();
45+
46+
private static final String INSERT_SQL = INSERT_STATEMENT.getSql();
47+
4348
@Parameter public Dialect dialect;
4449

50+
private Dialect currentDialect;
51+
4552
@Parameters(name = "dialect = {0}")
4653
public static Object[] data() {
4754
return Dialect.values();
4855
}
4956

5057
@Before
5158
public void setupDialect() {
52-
mockSpanner.putStatementResult(StatementResult.detectDialectResult(this.dialect));
59+
if (this.dialect != currentDialect) {
60+
mockSpanner.putStatementResult(StatementResult.detectDialectResult(this.dialect));
61+
this.currentDialect = dialect;
62+
}
5363
}
5464

5565
@After
@@ -77,8 +87,7 @@ public void testStatementTagForQuery() throws SQLException {
7787
connection
7888
.createStatement()
7989
.execute(String.format("set %sstatement_tag='my-tag'", getVariablePrefix()));
80-
try (ResultSet resultSet =
81-
connection.createStatement().executeQuery(SELECT_RANDOM_STATEMENT.getSql())) {
90+
try (ResultSet resultSet = connection.createStatement().executeQuery(SELECT_RANDOM_SQL)) {
8291
assertTrue(resultSet.next());
8392
}
8493
assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class));
@@ -87,8 +96,7 @@ public void testStatementTagForQuery() throws SQLException {
8796

8897
// Verify that the tag is cleared after having been used.
8998
mockSpanner.clearRequests();
90-
try (ResultSet resultSet =
91-
connection.createStatement().executeQuery(SELECT_RANDOM_STATEMENT.getSql())) {
99+
try (ResultSet resultSet = connection.createStatement().executeQuery(SELECT_RANDOM_SQL)) {
92100
assertTrue(resultSet.next());
93101
}
94102
assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class));
@@ -103,8 +111,7 @@ public void testTransactionTagForQuery() throws SQLException {
103111
connection
104112
.createStatement()
105113
.execute(String.format("set %stransaction_tag='my-tag'", getVariablePrefix()));
106-
try (ResultSet resultSet =
107-
connection.createStatement().executeQuery(SELECT_RANDOM_STATEMENT.getSql())) {
114+
try (ResultSet resultSet = connection.createStatement().executeQuery(SELECT_RANDOM_SQL)) {
108115
assertTrue(resultSet.next());
109116
}
110117
assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class));
@@ -113,8 +120,7 @@ public void testTransactionTagForQuery() throws SQLException {
113120

114121
// Verify that the tag is used for the entire transaction.
115122
mockSpanner.clearRequests();
116-
try (ResultSet resultSet =
117-
connection.createStatement().executeQuery(SELECT_RANDOM_STATEMENT.getSql())) {
123+
try (ResultSet resultSet = connection.createStatement().executeQuery(SELECT_RANDOM_SQL)) {
118124
assertTrue(resultSet.next());
119125
}
120126
assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class));
@@ -125,8 +131,7 @@ public void testTransactionTagForQuery() throws SQLException {
125131
connection.commit();
126132

127133
mockSpanner.clearRequests();
128-
try (ResultSet resultSet =
129-
connection.createStatement().executeQuery(SELECT_RANDOM_STATEMENT.getSql())) {
134+
try (ResultSet resultSet = connection.createStatement().executeQuery(SELECT_RANDOM_SQL)) {
130135
assertTrue(resultSet.next());
131136
}
132137
assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class));
@@ -143,8 +148,8 @@ public void testStatementTagForBatchDml() throws SQLException {
143148
.execute(String.format("set %sstatement_tag='my-tag'", getVariablePrefix()));
144149

145150
try (Statement statement = connection.createStatement()) {
146-
statement.addBatch(INSERT_STATEMENT.getSql());
147-
statement.addBatch(INSERT_STATEMENT.getSql());
151+
statement.addBatch(INSERT_SQL);
152+
statement.addBatch(INSERT_SQL);
148153
assertArrayEquals(new int[] {1, 1}, statement.executeBatch());
149154
}
150155
assertEquals(1, mockSpanner.countRequestsOfType(ExecuteBatchDmlRequest.class));
@@ -155,8 +160,8 @@ public void testStatementTagForBatchDml() throws SQLException {
155160
// Verify that the tag is cleared after having been used.
156161
mockSpanner.clearRequests();
157162
try (Statement statement = connection.createStatement()) {
158-
statement.addBatch(INSERT_STATEMENT.getSql());
159-
statement.addBatch(INSERT_STATEMENT.getSql());
163+
statement.addBatch(INSERT_SQL);
164+
statement.addBatch(INSERT_SQL);
160165
assertArrayEquals(new int[] {1, 1}, statement.executeBatch());
161166
}
162167
assertEquals(1, mockSpanner.countRequestsOfType(ExecuteBatchDmlRequest.class));
@@ -173,8 +178,8 @@ public void testTransactionTagForBatchDml() throws SQLException {
173178
.execute(String.format("set %stransaction_tag='my-tag'", getVariablePrefix()));
174179

175180
try (Statement statement = connection.createStatement()) {
176-
statement.addBatch(INSERT_STATEMENT.getSql());
177-
statement.addBatch(INSERT_STATEMENT.getSql());
181+
statement.addBatch(INSERT_SQL);
182+
statement.addBatch(INSERT_SQL);
178183
assertArrayEquals(new int[] {1, 1}, statement.executeBatch());
179184
}
180185
assertEquals(1, mockSpanner.countRequestsOfType(ExecuteBatchDmlRequest.class));
@@ -185,8 +190,8 @@ public void testTransactionTagForBatchDml() throws SQLException {
185190
// Verify that the tag is used for the entire transaction.
186191
mockSpanner.clearRequests();
187192
try (Statement statement = connection.createStatement()) {
188-
statement.addBatch(INSERT_STATEMENT.getSql());
189-
statement.addBatch(INSERT_STATEMENT.getSql());
193+
statement.addBatch(INSERT_SQL);
194+
statement.addBatch(INSERT_SQL);
190195
assertArrayEquals(new int[] {1, 1}, statement.executeBatch());
191196
}
192197
assertEquals(1, mockSpanner.countRequestsOfType(ExecuteBatchDmlRequest.class));
@@ -197,13 +202,33 @@ public void testTransactionTagForBatchDml() throws SQLException {
197202
connection.commit();
198203
mockSpanner.clearRequests();
199204
try (Statement statement = connection.createStatement()) {
200-
statement.addBatch(INSERT_STATEMENT.getSql());
201-
statement.addBatch(INSERT_STATEMENT.getSql());
205+
statement.addBatch(INSERT_SQL);
206+
statement.addBatch(INSERT_SQL);
202207
assertArrayEquals(new int[] {1, 1}, statement.executeBatch());
203208
}
204209
assertEquals(1, mockSpanner.countRequestsOfType(ExecuteBatchDmlRequest.class));
205210
request = mockSpanner.getRequestsOfType(ExecuteBatchDmlRequest.class).get(0);
206211
assertEquals("", request.getRequestOptions().getTransactionTag());
207212
}
208213
}
214+
215+
@Test
216+
public void testStatementTagInHint() throws SQLException {
217+
try (Connection connection = createConnection()) {
218+
try (ResultSet resultSet =
219+
connection
220+
.createStatement()
221+
.executeQuery(
222+
dialect == Dialect.POSTGRESQL
223+
? "/*@statement_tag='test-tag'*/SELECT 1"
224+
: "@{statement_tag='test-tag'}SELECT 1")) {
225+
assertTrue(resultSet.next());
226+
assertEquals(1L, resultSet.getLong(1));
227+
assertFalse(resultSet.next());
228+
}
229+
}
230+
assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class));
231+
ExecuteSqlRequest request = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(0);
232+
assertEquals("test-tag", request.getRequestOptions().getRequestTag());
233+
}
209234
}

0 commit comments

Comments
 (0)