Skip to content

Commit 65658e9

Browse files
committed
Make '*' optional for io.read() and io.lines().
1 parent 6190ad9 commit 65658e9

File tree

2 files changed

+90
-22
lines changed

2 files changed

+90
-22
lines changed

compat53/module.lua

Lines changed: 77 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,11 @@ if lua_version < "5.3" then
99
-- cache globals in upvalues
1010
local error, ipairs, pairs, pcall, require, select, setmetatable, type =
1111
error, ipairs, pairs, pcall, require, select, setmetatable, type
12-
local debug, math, package, string, table =
13-
debug, math, package, string, table
12+
local debug, io, math, package, string, table =
13+
debug, io, math, package, string, table
14+
local io_lines = io.lines
15+
local io_read = io.read
16+
local unpack = lua_version == "5.1" and unpack or table.unpack
1417

1518
-- create module table
1619
M = {}
@@ -21,6 +24,7 @@ if lua_version < "5.3" then
2124
setmetatable(M, M_meta)
2225

2326
-- create subtables
27+
M.io = setmetatable({}, { __index = io })
2428
M.math = setmetatable({}, { __index = math })
2529
M.string = setmetatable({}, { __index = string })
2630
M.table = setmetatable({}, { __index = table })
@@ -148,15 +152,13 @@ if lua_version < "5.3" then
148152

149153

150154
-- assert should allow non-string error objects
151-
do
152-
function M.assert(cond, ...)
153-
if cond then
154-
return cond, ...
155-
elseif select('#', ...) > 0 then
156-
error((...), 0)
157-
else
158-
error("assertion failed!", 0)
159-
end
155+
function M.assert(cond, ...)
156+
if cond then
157+
return cond, ...
158+
elseif select('#', ...) > 0 then
159+
error((...), 0)
160+
else
161+
error("assertion failed!", 0)
160162
end
161163
end
162164

@@ -180,13 +182,63 @@ if lua_version < "5.3" then
180182
end
181183

182184

185+
-- make '*' optional for io.read and io.lines
186+
do
187+
local function addasterisk(fmt)
188+
if type(fmt) == "string" and fmt:sub(1, 1) ~= "*" then
189+
return "*"..fmt
190+
else
191+
return fmt
192+
end
193+
end
194+
195+
function M.io.read(...)
196+
local n = select('#', ...)
197+
for i = 1, n do
198+
local a = select(i, ...)
199+
local b = addasterisk(a)
200+
-- as an optimization we only allocate a table for the
201+
-- modified format arguments when we have a '*' somewhere.
202+
if a ~= b then
203+
local args = { ... }
204+
args[i] = b
205+
for j = i+1, n do
206+
args[j] = addasterisk(args[j])
207+
end
208+
return io_read(unpack(args, 1, n))
209+
end
210+
end
211+
return io_read(...)
212+
end
213+
214+
-- PUC-Rio Lua 5.1 uses a different implementation for io.lines!
215+
function M.io.lines(...)
216+
local n = select('#', ...)
217+
for i = 2, n do
218+
local a = select(i, ...)
219+
local b = addasterisk(a)
220+
-- as an optimization we only allocate a table for the
221+
-- modified format arguments when we have a '*' somewhere.
222+
if a ~= b then
223+
local args = { ... }
224+
args[i] = b
225+
for j = i+1, n do
226+
args[j] = addasterisk(args[j])
227+
end
228+
return io_lines(unpack(args, 1, n))
229+
end
230+
end
231+
return io_lines(...)
232+
end
233+
end
234+
235+
183236
-- update table library (if C module not available)
184237
if not table_ok then
185238
local table_concat = table.concat
186239
local table_insert = table.insert
187240
local table_remove = table.remove
188241
local table_sort = table.sort
189-
local table_unpack = lua_version == "5.1" and unpack or table.unpack
190242

191243
function M.table.concat(list, sep, i, j)
192244
local mt = gmt(list)
@@ -344,7 +396,7 @@ if lua_version < "5.3" then
344396
i, j = i or 1, j or (has_len and mt.__len(list)) or #list
345397
return unpack_helper(list, i, j)
346398
else
347-
return table_unpack(list, i, j)
399+
return unpack(list, i, j)
348400
end
349401
end
350402
end -- update table library
@@ -358,9 +410,9 @@ if lua_version < "5.3" then
358410
#setmetatable({}, { __len = function() return 1 end }) == 1
359411

360412
-- cache globals in upvalues
361-
local load, loadfile, loadstring, setfenv, unpack, xpcall =
362-
load, loadfile, loadstring, setfenv, unpack, xpcall
363-
local coroutine, io, os = coroutine, io, os
413+
local load, loadfile, loadstring, setfenv, xpcall =
414+
load, loadfile, loadstring, setfenv, xpcall
415+
local coroutine, os = coroutine, os
364416
local coroutine_create = coroutine.create
365417
local coroutine_resume = coroutine.resume
366418
local coroutine_running = coroutine.running
@@ -382,7 +434,6 @@ if lua_version < "5.3" then
382434

383435
-- create subtables
384436
M.coroutine = setmetatable({}, { __index = coroutine })
385-
M.io = setmetatable({}, { __index = io })
386437
M.os = setmetatable({}, { __index = os })
387438
M.package = setmetatable({}, { __index = package })
388439

@@ -734,8 +785,6 @@ if lua_version < "5.3" then
734785
return helper(st, st.f:read(unpack(st, 1, st.n)))
735786
end
736787

737-
local valid_format = { ["*l"] = true, ["*n"] = true, ["*a"] = true }
738-
739788
function M.io.lines(fname, ...)
740789
local doclose, file, msg
741790
if fname ~= nil then
@@ -746,8 +795,15 @@ if lua_version < "5.3" then
746795
end
747796
local st = { f=file, doclose=doclose, n=select('#', ...), ... }
748797
for i = 1, st.n do
749-
if type(st[i]) ~= "number" and not valid_format[st[i]] then
750-
error("bad argument #"..(i+1).." to 'for iterator' (invalid format)", 2)
798+
local t = type(st[i])
799+
if t == "string" then
800+
local fmt = st[i]:match("^%*?([aln])")
801+
if not fmt then
802+
error("bad argument #"..(i+1).." to 'for iterator' (invalid format)", 2)
803+
end
804+
st[i] = "*"..fmt
805+
elseif t ~= "number" then
806+
error("bad argument #"..(i+1).." to 'for iterator' (invalid format)", 2)
751807
end
752808
end
753809
return lines_iterator, st

tests/test.lua

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,18 @@ do
537537
end
538538

539539

540+
___''
541+
do
542+
writefile("data.txt", "123 18.8 hello world\ni'm here\n")
543+
io.input("data.txt")
544+
print(io.read("*n", "*number", "*l", "*a"))
545+
io.input("data.txt")
546+
print(io.read("n", "number", "l", "a"))
547+
io.input(io.stdin)
548+
os.remove("data.txt")
549+
end
550+
551+
540552
___''
541553
do
542554
writefile("data.txt", "123 18.8 hello world\ni'm here\n")
@@ -548,7 +560,7 @@ do
548560
print("io.lines()", l)
549561
break
550562
end
551-
for n1,n2,rest in io.lines("data.txt", "*n", "*n", "*a") do
563+
for n1,n2,rest in io.lines("data.txt", "*n", "n", "*a") do
552564
print("io.lines()", n1, n2, rest)
553565
end
554566
for l in io.lines("data.txt") do

0 commit comments

Comments
 (0)