Skip to content

Commit 04c18ad

Browse files
authored
Allow ~ and fix for schema_export (#46)
* Update dolt.py * Update schema_export to match cli the --filename flag appears to have been removed from the cli * Add test * Fix the test
1 parent a9fc7b2 commit 04c18ad

File tree

2 files changed

+26
-30
lines changed

2 files changed

+26
-30
lines changed

doltcli/dolt.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,8 @@ class Dolt(DoltT):
267267
"""
268268

269269
def __init__(self, repo_dir: str, print_output: Optional[bool] = None):
270+
# allow ~ to be used in paths
271+
repo_dir = os.path.expanduser(repo_dir)
270272
self.repo_dir = repo_dir
271273
self._print_output = print_output or False
272274

@@ -1366,7 +1368,7 @@ def schema_export(self, table: str, filename: Optional[str] = None):
13661368
args = ["schema", "export", table]
13671369

13681370
if filename:
1369-
args.extend(["--filename", filename])
1371+
args.extend([filename])
13701372
_execute(args, self.repo_dir)
13711373
return True
13721374
else:

tests/test_dolt.py

Lines changed: 23 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,18 @@ def test_init(tmp_path):
113113
shutil.rmtree(repo_data_dir)
114114

115115

116+
def test_home_path():
117+
path = "~/.dolt_test"
118+
if os.path.exists(os.path.expanduser(path)):
119+
shutil.rmtree(os.path.expanduser(path))
120+
os.mkdir(os.path.expanduser(path))
121+
# Create empty file
122+
open(os.path.expanduser(path + "/.dolt"), "a").close()
123+
Dolt(path)
124+
assert os.path.exists(path)
125+
shutil.rmtree(path)
126+
127+
116128
def test_bad_repo_path(tmp_path):
117129
bad_repo_path = tmp_path
118130
with pytest.raises(ValueError):
@@ -205,10 +217,10 @@ def test_merge_conflict(create_test_table: Tuple[Dolt, str]):
205217
with pytest.raises(DoltException):
206218
repo.merge("other", message_merge)
207219

208-
#commits = list(repo.log().values())
209-
#head_of_main = commits[0]
220+
# commits = list(repo.log().values())
221+
# head_of_main = commits[0]
210222

211-
#assert head_of_main.message == message_two
223+
# assert head_of_main.message == message_two
212224

213225

214226
def test_dolt_log(create_test_table: Tuple[Dolt, str]):
@@ -400,10 +412,7 @@ def test_branch(create_test_table: Tuple[Dolt, str]):
400412
repo.checkout("dosac", checkout_branch=True)
401413
repo.checkout("main")
402414
next_active_branch, next_branches = repo.branch()
403-
assert (
404-
set(branch.name for branch in next_branches) == {"main", "dosac"}
405-
and next_active_branch.name == "main"
406-
)
415+
assert set(branch.name for branch in next_branches) == {"main", "dosac"} and next_active_branch.name == "main"
407416

408417
repo.checkout("dosac")
409418
different_active_branch, _ = repo.branch()
@@ -552,17 +561,13 @@ def test_sql(create_test_table: Tuple[Dolt, str]):
552561

553562
def test_sql_json(create_test_table: Tuple[Dolt, str]):
554563
repo, test_table = create_test_table
555-
result = repo.sql(
556-
query="SELECT * FROM `{table}`".format(table=test_table), result_format="json"
557-
)["rows"]
564+
result = repo.sql(query="SELECT * FROM `{table}`".format(table=test_table), result_format="json")["rows"]
558565
_verify_against_base_rows(result)
559566

560567

561568
def test_sql_csv(create_test_table: Tuple[Dolt, str]):
562569
repo, test_table = create_test_table
563-
result = repo.sql(
564-
query="SELECT * FROM `{table}`".format(table=test_table), result_format="csv"
565-
)
570+
result = repo.sql(query="SELECT * FROM `{table}`".format(table=test_table), result_format="csv")
566571
_verify_against_base_rows(result)
567572

568573

@@ -604,10 +609,7 @@ def test_config_global(init_empty_test_repo: Dolt):
604609
Dolt.config_global(add=True, name="user.name", value=test_username)
605610
Dolt.config_global(add=True, name="user.email", value=test_email)
606611
updated_config = Dolt.config_global(list=True)
607-
assert (
608-
updated_config["user.name"] == test_username
609-
and updated_config["user.email"] == test_email
610-
)
612+
assert updated_config["user.name"] == test_username and updated_config["user.email"] == test_email
611613
Dolt.config_global(add=True, name="user.name", value=current_global_config["user.name"])
612614
Dolt.config_global(add=True, name="user.email", value=current_global_config["user.email"])
613615
reset_config = Dolt.config_global(list=True)
@@ -623,9 +625,7 @@ def test_config_local(init_empty_test_repo: Dolt):
623625
repo.config_local(add=True, name="user.email", value=test_email)
624626
local_config = repo.config_local(list=True)
625627
global_config = Dolt.config_global(list=True)
626-
assert (
627-
local_config["user.name"] == test_username and local_config["user.email"] == test_email
628-
)
628+
assert local_config["user.name"] == test_username and local_config["user.email"] == test_email
629629
assert global_config["user.name"] == current_global_config["user.name"]
630630
assert global_config["user.email"] == current_global_config["user.email"]
631631

@@ -677,18 +677,14 @@ def test_clone_new_dir(tmp_path):
677677
def test_dolt_sql_csv(init_empty_test_repo: Dolt):
678678
dolt = init_empty_test_repo
679679
write_rows(dolt, "test_table", BASE_TEST_ROWS, commit=True)
680-
result = dolt.sql(
681-
"SELECT `name` as name, `id` as id FROM test_table ORDER BY id", result_format="csv"
682-
)
680+
result = dolt.sql("SELECT `name` as name, `id` as id FROM test_table ORDER BY id", result_format="csv")
683681
compare_rows_helper(BASE_TEST_ROWS, result)
684682

685683

686684
def test_dolt_sql_json(init_empty_test_repo: Dolt):
687685
dolt = init_empty_test_repo
688686
write_rows(dolt, "test_table", BASE_TEST_ROWS, commit=True)
689-
result = dolt.sql(
690-
"SELECT `name` as name, `id` as id FROM test_table ", result_format="json"
691-
)
687+
result = dolt.sql("SELECT `name` as name, `id` as id FROM test_table ", result_format="json")
692688
# JSON return value preserves some type information, we cast back to a string
693689
for row in result["rows"]:
694690
row["id"] = str(row["id"])
@@ -700,9 +696,7 @@ def test_dolt_sql_file(init_empty_test_repo: Dolt):
700696

701697
with tempfile.NamedTemporaryFile() as f:
702698
write_rows(dolt, "test_table", BASE_TEST_ROWS, commit=True)
703-
result = dolt.sql(
704-
"SELECT `name` as name, `id` as id FROM test_table ", result_file=f.name
705-
)
699+
result = dolt.sql("SELECT `name` as name, `id` as id FROM test_table ", result_file=f.name)
706700
res = read_csv_to_dict(f.name)
707701
compare_rows_helper(BASE_TEST_ROWS, res)
708702

0 commit comments

Comments
 (0)