Skip to content

Commit 484c9eb

Browse files
authored
Use Table column names in load! (#217)
* check Table.jl names in load! * nothing schema case
1 parent 59d033e commit 484c9eb

File tree

3 files changed

+60
-9
lines changed

3 files changed

+60
-9
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,5 @@ CMakeLists.txt.user
1010
/test/test2.sqlite
1111
/docs/build
1212
.vscode
13+
14+
Manifest.toml

src/tables.jl

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,20 @@ function createtable!(db::DB, nm::AbstractString, ::Tables.Schema{names, types};
119119
return execute(db, "CREATE $temp TABLE $ifnotexists $nm ($(join(columns, ',')))")
120120
end
121121

122+
struct TableInfo
123+
exists::Bool
124+
names::Vector{String}
125+
end
126+
function tableinfo(db::DB, name::AbstractString)
127+
table_info = Tables.columntable(DBInterface.execute(db, "pragma table_info($name)"))
128+
exists = table_info !== NamedTuple()
129+
if exists
130+
return TableInfo(exists, table_info.name)
131+
else
132+
return TableInfo(exists, String[])
133+
end
134+
end
135+
122136
"""
123137
source |> SQLite.load!(db::SQLite.DB, tablename::String; temp::Bool=false, ifnotexists::Bool=false, analyze::Bool=false)
124138
SQLite.load!(source, db, tablename; temp=false, ifnotexists=false, analyze::Bool=false)
@@ -136,22 +150,36 @@ load!(db::DB, table::AbstractString="sqlitejl_"*Random.randstring(5); kwargs...)
136150
function load!(itr, db::DB, name::AbstractString="sqlitejl_"*Random.randstring(5); kwargs...)
137151
# check if table exists
138152
nm = esc_id(name)
139-
status = execute(db, "pragma table_info($nm)")
153+
db_tableinfo = tableinfo(db, nm)
140154
rows = Tables.rows(itr)
141155
sch = Tables.schema(rows)
142-
return load!(sch, rows, db, nm, name, status == SQLITE_DONE; kwargs...)
156+
return load!(sch, rows, db, nm, name, db_tableinfo; kwargs...)
143157
end
144158

145159
checkdupnames(names) = length(unique(map(x->lowercase(String(x)), names))) == length(names) || error("duplicate case-insensitive column names detected; sqlite doesn't allow duplicate column names and treats them case insensitive")
146160

147-
function load!(sch::Tables.Schema, rows, db::DB, nm::AbstractString, name, shouldcreate; temp::Bool=false, ifnotexists::Bool=false, analyze::Bool=false)
161+
function checknames(::Tables.Schema{names}, db_names::Vector{String}) where {names}
162+
table_names = Set(string.(names))
163+
db_names = Set(db_names)
164+
165+
if table_names != db_names
166+
error("Error loading, column names from table $(collect(table_names)) do not match database names $(collect(db_names))")
167+
end
168+
end
169+
170+
function load!(sch::Tables.Schema, rows, db::DB, nm::AbstractString, name, db_tableinfo::TableInfo; temp::Bool=false, ifnotexists::Bool=false, analyze::Bool=false)
148171
# check for case-insensitive duplicate column names (sqlite doesn't allow)
149172
checkdupnames(sch.names)
150173
# create table if needed
151-
shouldcreate && createtable!(db, nm, sch; temp=temp, ifnotexists=ifnotexists)
174+
if db_tableinfo.exists
175+
checknames(sch, db_tableinfo.names)
176+
else
177+
createtable!(db, nm, sch; temp=temp, ifnotexists=ifnotexists)
178+
end
152179
# build insert statement
180+
columns = join(sch.names, ",")
153181
params = chop(repeat("?,", length(sch.names)))
154-
stmt = Stmt(db, "INSERT INTO $nm VALUES ($params)")
182+
stmt = Stmt(db, "INSERT INTO $nm ($columns) VALUES ($params)")
155183
# start a transaction for inserting rows
156184
transaction(db) do
157185
for row in rows
@@ -167,18 +195,23 @@ function load!(sch::Tables.Schema, rows, db::DB, nm::AbstractString, name, shoul
167195
end
168196

169197
# unknown schema case
170-
function load!(::Nothing, rows, db::DB, nm::AbstractString, name, shouldcreate; temp::Bool=false, ifnotexists::Bool=false, analyze::Bool=false)
198+
function load!(::Nothing, rows, db::DB, nm::AbstractString, name, db_tableinfo::TableInfo; temp::Bool=false, ifnotexists::Bool=false, analyze::Bool=false)
171199
state = iterate(rows)
172200
state === nothing && return nm
173201
row, st = state
174202
names = propertynames(row)
175203
sch = Tables.Schema(names, nothing)
176204
checkdupnames(sch.names)
177205
# create table if needed
178-
shouldcreate && createtable!(db, nm, sch; temp=temp, ifnotexists=ifnotexists)
206+
if db_tableinfo.exists
207+
checknames(sch, db_tableinfo.names)
208+
else
209+
createtable!(db, nm, sch; temp=temp, ifnotexists=ifnotexists)
210+
end
179211
# build insert statement
180212
params = chop(repeat("?,", length(names)))
181-
stmt = Stmt(db, "INSERT INTO $nm VALUES ($params)")
213+
columns = join(sch.names, ",")
214+
stmt = Stmt(db, "INSERT INTO $nm ($columns) VALUES ($params)")
182215
# start a transaction for inserting rows
183216
transaction(db) do
184217
while true

test/runtests.jl

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,11 +291,27 @@ row2 = first(r)
291291
@test DBInterface.lastrowid(r) == 3
292292

293293
r = DBInterface.execute(db, "SELECT * FROM T") |> columntable
294-
SQLite.load!(nothing, Tables.rows(r), db, "T2", "T2", true)
294+
SQLite.load!(nothing, Tables.rows(r), db, "T2", "T2", SQLite.tableinfo(db, "T2"))
295295
r2 = DBInterface.execute(db, "SELECT * FROM T2") |> columntable
296296
@test r == r2
297297

298298
# throw informative error on duplicate column names #193
299299
@test_throws ErrorException SQLite.load!((a=[1,2,3], A=[1,2,3]), db)
300300

301+
db = SQLite.DB()
302+
# Table should map by name #216
303+
tbl1 = (a = [1, 2, 3], b = [4, 5, 6])
304+
tbl2 = (b = [7, 8, 9], a = [4, 5, 6])
305+
SQLite.load!(tbl1, db, "data")
306+
SQLite.load!(tbl2, db, "data")
307+
308+
res = DBInterface.execute(db, "SELECT * FROM data") |> columntable
309+
expected = (a=[1, 2, 3, 4, 5, 6], b=[4, 5, 6, 7, 8, 9])
310+
@test res == expected
311+
312+
# Table should error if names don't match #216
313+
tbl3 = (c = [7, 8, 9], a = [4, 5, 6])
314+
@test_throws ErrorException SQLite.load!(tbl3, db, "data")
315+
316+
301317
end # @testset

0 commit comments

Comments
 (0)