Skip to content

Commit 5137c33

Browse files
authored
Do not allow guards in assert/1 (#13817)
1 parent 5cb8633 commit 5137c33

File tree

2 files changed

+50
-0
lines changed

2 files changed

+50
-0
lines changed

lib/ex_unit/lib/ex_unit/assertions.ex

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,12 @@ defmodule ExUnit.Assertions do
121121
122122
Even though the match works, `assert` still expects a truth
123123
value. In such cases, simply use `==/2` or `match?/2`.
124+
125+
If you need more complex pattern matching using guards, you
126+
need to use `match?/2`:
127+
128+
assert match?([%{id: id} | _] when is_integer(id), records)
129+
124130
"""
125131
defmacro assert({:=, meta, [left, right]} = assertion) do
126132
code = escape_quoted(:assert, meta, assertion)
@@ -357,6 +363,23 @@ defmodule ExUnit.Assertions do
357363
end
358364

359365
@doc false
366+
def __match__({:when, _, _} = left, right, _, _, _) do
367+
suggestion =
368+
quote do
369+
assert match?(unquote(left), unquote(right))
370+
end
371+
372+
raise ArgumentError, """
373+
invalid pattern in assert/1:
374+
375+
#{Macro.to_string(left) |> Inspect.Error.pad(2)}
376+
377+
To assert with guards, use match?/2:
378+
379+
#{Macro.to_string(suggestion) |> Inspect.Error.pad(2)}
380+
"""
381+
end
382+
360383
def __match__(left, right, code, check, caller) do
361384
left = __expand_pattern__(left, caller)
362385
vars = collect_vars_from_pattern(left)

lib/ex_unit/test/ex_unit/assertions_test.exs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,33 @@ defmodule ExUnit.AssertionsTest do
256256
end
257257
end
258258

259+
test "assert match with `when` in the pattern fails" do
260+
message = """
261+
invalid pattern in assert\/1:
262+
263+
x when is_map(x)
264+
265+
To assert with guards, use match?/2:
266+
267+
assert match?(x when is_map(x), %{})
268+
"""
269+
270+
assert_raise ArgumentError, message, fn ->
271+
Code.eval_string("""
272+
defmodule AssertGuard do
273+
import ExUnit.Assertions
274+
275+
def run do
276+
assert (x when is_map(x)) = %{}
277+
end
278+
end
279+
""")
280+
end
281+
after
282+
:code.purge(AssertGuard)
283+
:code.delete(AssertGuard)
284+
end
285+
259286
test "assert match with __ENV__ in the pattern" do
260287
message =
261288
ExUnit.CaptureIO.capture_io(:stderr, fn ->

0 commit comments

Comments
 (0)