Skip to content

Commit 40585cf

Browse files
authored
PCSM-200: Shard collections with non-default collation (#139)
1 parent 9239fee commit 40585cf

File tree

4 files changed

+73
-15
lines changed

4 files changed

+73
-15
lines changed

plm/catalog.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,22 @@ func (c *Catalog) CreateIndexes(
378378
lg.Info("Create hidden index as unhidden: " + index.Name)
379379
}
380380

381+
if index.Collation == nil {
382+
lg.Info("Create index with missing collation, setting simple collation: " + index.Name)
383+
384+
d := bson.D{{"locale", "simple"}}
385+
386+
collation, err := bson.Marshal(d)
387+
if err != nil {
388+
return errors.Wrapf(err, "marshal simple collation for index %s.%s.%s",
389+
db, coll, index.Name)
390+
}
391+
392+
idxCopy := *index
393+
idxCopy.Collation = collation
394+
index = &idxCopy
395+
}
396+
381397
idxs = append(idxs, index)
382398
}
383399

@@ -1147,6 +1163,7 @@ func (c *Catalog) ShardCollection(
11471163
cmd := bson.D{
11481164
{"shardCollection", db + "." + coll},
11491165
{"key", shardKey},
1166+
{"collation", bson.D{{"locale", "simple"}}},
11501167
}
11511168

11521169
if unique {

plm/clone.go

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -417,12 +417,7 @@ func (c *Clone) doCollectionClone(
417417
return ErrTimeseriesUnsupported
418418
}
419419

420-
shInfo, err := topo.GetCollectionShardingInfo(ctx, c.source, ns.Database, ns.Collection)
421-
if err != nil && !errors.Is(err, topo.ErrNotFound) {
422-
return errors.Wrap(err, "get sharding info")
423-
}
424-
425-
err = c.createCollection(ctx, ns, spec, shInfo)
420+
err = c.createCollection(ctx, ns, spec)
426421
if err != nil {
427422
if !errors.Is(err, context.Canceled) {
428423
lg.Errorf(err, "Failed to create %q collection", ns.String())
@@ -438,7 +433,21 @@ func (c *Clone) doCollectionClone(
438433
}
439434
}
440435

441-
lg.Infof("Collection %q has been created", ns.String())
436+
lg.Infof("Collection %q created", ns.String())
437+
438+
shInfo, err := topo.GetCollectionShardingInfo(ctx, c.source, ns.Database, ns.Collection)
439+
if err != nil && !errors.Is(err, topo.ErrNotFound) {
440+
return errors.Wrap(err, "get sharding info")
441+
}
442+
443+
if shInfo != nil && shInfo.IsSharded() {
444+
err := c.catalog.ShardCollection(ctx, ns.Database, ns.Collection, shInfo.ShardKey, shInfo.Unique)
445+
if err != nil {
446+
return errors.Wrap(err, "shard collection")
447+
}
448+
}
449+
450+
lg.Infof("Collection %q sharded", ns.String())
442451

443452
c.catalog.SetCollectionTimestamp(ctx, ns.Database, ns.Collection, capturedAt)
444453

@@ -694,7 +703,6 @@ func (c *Clone) createCollection(
694703
ctx context.Context,
695704
ns Namespace,
696705
spec *topo.CollectionSpecification,
697-
shInfo *topo.ShardingInfo,
698706
) error {
699707
if spec.Type == topo.TypeTimeseries {
700708
return ErrTimeseriesUnsupported
@@ -717,13 +725,6 @@ func (c *Clone) createCollection(
717725
return errors.Wrap(err, "create collection")
718726
}
719727

720-
if shInfo != nil && shInfo.IsSharded() {
721-
err := c.catalog.ShardCollection(ctx, ns.Database, ns.Collection, shInfo.ShardKey, shInfo.Unique)
722-
if err != nil {
723-
return errors.Wrap(err, "shard collection")
724-
}
725-
}
726-
727728
return nil
728729
}
729730

tests/test_collections_sharded.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,33 @@ def test_rename_sharded(t: Testing, phase: Runner.Phase):
3838
t.source["db_1"]["coll_1"].rename("coll_2")
3939

4040
t.compare_all()
41+
42+
43+
@pytest.mark.parametrize("phase", [Runner.Phase.APPLY, Runner.Phase.CLONE])
44+
def test_create_collection_with_collation(t: Testing, phase: Runner.Phase):
45+
with t.run(phase):
46+
t.source["db_1"].create_collection("coll_1", collation={"locale": "en", "strength": 2})
47+
t.source.admin.command(
48+
"shardCollection", "db_1.coll_1", key={"name": 1}, collation={"locale": "simple"}
49+
)
50+
51+
t.compare_all_sharded()
52+
53+
54+
@pytest.mark.parametrize("phase", [Runner.Phase.APPLY, Runner.Phase.CLONE])
55+
def test_create_collection_with_collation_with_shard_key_index_prefix(
56+
t: Testing, phase: Runner.Phase
57+
):
58+
with t.run(phase):
59+
t.source["db_1"].create_collection("coll_2", collation={"locale": "en", "strength": 2})
60+
t.source["db_1"]["coll_2"].create_index(
61+
[("name", 1), ("date", 1), ("age", 1)], collation={"locale": "simple"}
62+
)
63+
t.source.admin.command(
64+
"shardCollection",
65+
"db_1.coll_1",
66+
key={"name": 1, "date": 1},
67+
collation={"locale": "simple"},
68+
)
69+
70+
t.compare_all_sharded()

tests/test_indexes.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,16 @@ def test_create_with_collation(t: Testing, phase: Runner.Phase):
2424
t.compare_all()
2525

2626

27+
@pytest.mark.parametrize("phase", [Runner.Phase.APPLY, Runner.Phase.CLONE])
28+
def test_create_with_inherited_collation(t: Testing, phase: Runner.Phase):
29+
with t.run(phase):
30+
t.source["db_1"].create_collection("coll_1_collation", collation={"locale": "en_US"})
31+
t.source["db_1"]["coll_1_collation"].create_index({"i": 1})
32+
t.source["db_1"]["coll_1_collation"].create_index({"j": 1}, collation={"locale": "simple"})
33+
34+
t.compare_all()
35+
36+
2737
@pytest.mark.parametrize("phase", [Runner.Phase.APPLY, Runner.Phase.CLONE])
2838
def test_create_unique(t: Testing, phase: Runner.Phase):
2939
with t.run(phase):

0 commit comments

Comments
 (0)