Skip to content

Commit 75bf1d4

Browse files
committed
Update SpannerSchema to support postgres tokenlist
* Refactor to use a immutablemap for fast mapping
1 parent 73b4d53 commit 75bf1d4

File tree

2 files changed

+42
-60
lines changed

2 files changed

+42
-60
lines changed

sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerSchema.java

Lines changed: 38 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import com.google.cloud.spanner.Type;
2323
import java.io.Serializable;
2424
import java.util.List;
25+
import java.util.Map;
2526
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
2627
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
2728
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableListMultimap;
@@ -161,54 +162,59 @@ static Column create(String name, String spannerType, Dialect dialect) {
161162

162163
public abstract Type getType();
163164

165+
private static final Map<String, Type> GOOGLE_STANDARD_SQL_TYPE_MAP =
166+
ImmutableMap.<String, Type>builder()
167+
.put("BOOL", Type.bool())
168+
.put("INT64", Type.int64())
169+
.put("FLOAT32", Type.float32())
170+
.put("FLOAT64", Type.float64())
171+
.put("UUID", Type.string())
172+
.put("TOKENLIST", Type.bytes())
173+
.put("TIMESTAMP", Type.timestamp())
174+
.put("DATE", Type.date())
175+
.put("NUMERIC", Type.numeric())
176+
.put("JSON", Type.json())
177+
.build();
178+
private static final Map<String, Type> POSTGRES_TYPE_MAP =
179+
ImmutableMap.<String, Type>builder()
180+
.put("BOOLEAN", Type.bool())
181+
.put("BIGINT", Type.int64())
182+
.put("REAL", Type.float32())
183+
.put("DOUBLE PRECISION", Type.float64())
184+
.put("TEXT", Type.string())
185+
.put("BYTEA", Type.bytes())
186+
.put("TIMESTAMP WITH TIME ZONE", Type.timestamp())
187+
.put("DATE", Type.date())
188+
.put("SPANNER.COMMIT_TIMESTAMP", Type.timestamp())
189+
.put("SPANNER.TOKENLIST", Type.bytes())
190+
.put("UUID", Type.string())
191+
.build();
192+
164193
private static Type parseSpannerType(String spannerType, Dialect dialect) {
165194
String originalSpannerType = spannerType;
166195
spannerType = spannerType.toUpperCase();
167196
switch (dialect) {
168197
case GOOGLE_STANDARD_SQL:
169-
if ("BOOL".equals(spannerType)) {
170-
return Type.bool();
171-
}
172-
if ("INT64".equals(spannerType)) {
173-
return Type.int64();
174-
}
175-
if ("FLOAT32".equals(spannerType)) {
176-
return Type.float32();
177-
}
178-
if ("FLOAT64".equals(spannerType)) {
179-
return Type.float64();
198+
Type type = GOOGLE_STANDARD_SQL_TYPE_MAP.get(spannerType);
199+
if (type != null) {
200+
return type;
180201
}
181202
if (spannerType.startsWith("STRING")) {
182203
return Type.string();
183204
}
184-
if ("UUID".equals(spannerType)) {
185-
return Type.string();
186-
}
187205
if (spannerType.startsWith("BYTES")) {
188206
return Type.bytes();
189207
}
190-
if ("TOKENLIST".equals(spannerType)) {
191-
return Type.bytes();
192-
}
193-
if ("TIMESTAMP".equals(spannerType)) {
194-
return Type.timestamp();
195-
}
196-
if ("DATE".equals(spannerType)) {
197-
return Type.date();
198-
}
199-
if ("NUMERIC".equals(spannerType)) {
200-
return Type.numeric();
201-
}
202-
if ("JSON".equals(spannerType)) {
203-
return Type.json();
204-
}
205208
if (spannerType.startsWith("ARRAY")) {
206209
// Substring "ARRAY<xxx>"
207210
String spannerArrayType =
208211
originalSpannerType.substring(6, originalSpannerType.length() - 1);
209212
Type itemType = parseSpannerType(spannerArrayType, dialect);
210213
return Type.array(itemType);
211214
}
215+
if (spannerType.startsWith("CHARACTER VARYING")) {
216+
return Type.string();
217+
}
212218
if (spannerType.startsWith("PROTO")) {
213219
// Substring "PROTO<xxx>"
214220
String spannerProtoType =
@@ -230,42 +236,16 @@ private static Type parseSpannerType(String spannerType, Dialect dialect) {
230236
Type itemType = parseSpannerType(spannerArrayType, dialect);
231237
return Type.array(itemType);
232238
}
233-
if ("BOOLEAN".equals(spannerType)) {
234-
return Type.bool();
235-
}
236-
if ("BIGINT".equals(spannerType)) {
237-
return Type.int64();
238-
}
239-
if ("REAL".equals(spannerType)) {
240-
return Type.float32();
241-
}
242-
if ("DOUBLE PRECISION".equals(spannerType)) {
243-
return Type.float64();
244-
}
245-
if (spannerType.startsWith("CHARACTER VARYING") || "TEXT".equals(spannerType)) {
246-
return Type.string();
247-
}
248-
if ("BYTEA".equals(spannerType)) {
249-
return Type.bytes();
250-
}
251-
if ("TIMESTAMP WITH TIME ZONE".equals(spannerType)) {
252-
return Type.timestamp();
253-
}
254-
if ("DATE".equals(spannerType)) {
255-
return Type.date();
239+
type = POSTGRES_TYPE_MAP.get(spannerType);
240+
if (type != null) {
241+
return type;
256242
}
257243
if (spannerType.startsWith("NUMERIC")) {
258244
return Type.pgNumeric();
259245
}
260-
if ("SPANNER.COMMIT_TIMESTAMP".equals(spannerType)) {
261-
return Type.timestamp();
262-
}
263246
if (spannerType.startsWith("JSONB")) {
264247
return Type.pgJsonb();
265248
}
266-
if ("UUID".equals(spannerType)) {
267-
return Type.string();
268-
}
269249
throw new IllegalArgumentException("Unknown spanner type " + spannerType);
270250
default:
271251
throw new IllegalArgumentException("Unrecognized dialect: " + dialect.name());

sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerSchemaTest.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,14 +86,16 @@ public void testSinglePgTable() throws Exception {
8686
.addColumn("test", "numericVal", "numeric")
8787
.addColumn("test", "commitTime", "spanner.commit_timestamp")
8888
.addColumn("test", "jsonbCol", "jsonb")
89+
.addColumn("test", "tokens", "spanner.tokenlist")
8990
.addColumn("test", "uuidCol", "uuid")
9091
.build();
9192

9293
assertEquals(1, schema.getTables().size());
93-
assertEquals(6, schema.getColumns("test").size());
94+
assertEquals(7, schema.getColumns("test").size());
9495
assertEquals(1, schema.getKeyParts("test").size());
9596
assertEquals(Type.timestamp(), schema.getColumns("test").get(3).getType());
96-
assertEquals(Type.string(), schema.getColumns("test").get(5).getType());
97+
assertEquals(Type.bytes(), schema.getColumns("test").get(5).getType());
98+
assertEquals(Type.string(), schema.getColumns("test").get(6).getType());
9799
}
98100

99101
@Test

0 commit comments

Comments
 (0)