Skip to content

Commit 5900252

Browse files
authored
Add is_protobuf_message/1 guard for message identification (#428)
1 parent 4328993 commit 5900252

File tree

6 files changed

+91
-3
lines changed

6 files changed

+91
-3
lines changed

lib/protobuf.ex

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,37 @@ defmodule Protobuf do
5151
"""
5252
@type unknown_field() :: {field_number :: integer(), wire_type(), value :: any()}
5353

54+
@doc """
55+
Checks if the given value is a Protobuf message struct.
56+
57+
This guard checks for the `__protobuf__: true` marker that is automatically added
58+
to all Protobuf message structs.
59+
60+
## Examples
61+
62+
defmodule MyMessage do
63+
use Protobuf, syntax: :proto3
64+
field :name, 1, type: :string
65+
end
66+
67+
message = %MyMessage{name: "test"}
68+
69+
iex> is_protobuf_message(message)
70+
true
71+
72+
iex> is_protobuf_message(%{})
73+
false
74+
75+
Can be used in guards:
76+
77+
def process_message(msg) when is_protobuf_message(msg) do
78+
msg.name
79+
end
80+
81+
"""
82+
defguard is_protobuf_message(value)
83+
when is_map(value) and :erlang.is_map_key(:__protobuf__, value)
84+
5485
defmacro __using__(opts) do
5586
quote location: :keep do
5687
import Protobuf.DSL, only: [field: 3, field: 2, oneof: 2, extend: 4, extensions: 1]

lib/protobuf/dsl.ex

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -459,8 +459,10 @@ defmodule Protobuf.DSL do
459459
end
460460

461461
unknown_fields = {:__unknown_fields__, _default = []}
462+
protobuf_marker = {:__protobuf__, _default = true}
462463

463-
struct_fields = regular_fields ++ oneof_fields ++ extension_fields ++ [unknown_fields]
464+
struct_fields =
465+
regular_fields ++ oneof_fields ++ extension_fields ++ [unknown_fields, protobuf_marker]
464466

465467
quote do
466468
defstruct unquote(Macro.escape(struct_fields))

lib/protobuf/dsl/typespecs.ex

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,12 @@ defmodule Protobuf.DSL.Typespecs do
4545
)}
4646
]
4747

48-
field_specs = regular_fields ++ oneof_fields ++ extension_fields ++ unknown_fields
48+
protobuf_marker = [
49+
{:__protobuf__, quote(do: true)}
50+
]
51+
52+
field_specs =
53+
regular_fields ++ oneof_fields ++ extension_fields ++ unknown_fields ++ protobuf_marker
4954

5055
quote do: %__MODULE__{unquote_splicing(field_specs)}
5156
end

test/protobuf/dsl/typespecs_test.exs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ defmodule Protobuf.DSL.TypespecsTest do
77

88
@unknown_fields_spec quote(
99
do: [
10-
__unknown_fields__: [Protobuf.unknown_field()]
10+
__unknown_fields__: [Protobuf.unknown_field()],
11+
__protobuf__: true
1112
]
1213
)
1314

test/protobuf/protobuf_test.exs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
defmodule Protobuf.ProtobufTest do
22
use ExUnit.Case, async: false
3+
import Protobuf, only: [is_protobuf_message: 1]
34

45
test "load_extensions/0 is a noop" do
56
assert loaded_extensions() == 19
@@ -47,6 +48,45 @@ defmodule Protobuf.ProtobufTest do
4748
end
4849
end
4950

51+
describe "is_protobuf_message/1" do
52+
test "returns true for protobuf message structs" do
53+
message = %TestMsg.Foo{a: 42}
54+
assert is_protobuf_message(message)
55+
end
56+
57+
test "returns false for non-protobuf structs" do
58+
refute is_protobuf_message(%URI{})
59+
end
60+
61+
test "returns false for non-structs" do
62+
refute is_protobuf_message(%{})
63+
refute is_protobuf_message("string")
64+
refute is_protobuf_message(42)
65+
refute is_protobuf_message(nil)
66+
end
67+
68+
test "works in pattern matching" do
69+
message = %TestMsg.Foo{a: 42}
70+
71+
result =
72+
case message do
73+
%{__protobuf__: true} -> :protobuf
74+
_ -> :not_protobuf
75+
end
76+
77+
assert result == :protobuf
78+
end
79+
80+
test "works in guards" do
81+
message = %TestMsg.Foo{a: 42}
82+
assert check_protobuf_with_guard(message) == true
83+
assert check_protobuf_with_guard(%URI{}) == false
84+
end
85+
end
86+
87+
defp check_protobuf_with_guard(value) when is_protobuf_message(value), do: true
88+
defp check_protobuf_with_guard(_), do: false
89+
5090
defp loaded_extensions do
5191
Enum.count(:persistent_term.get(), &match?({{Protobuf.Extension, _, _}, _}, &1))
5292
end

test/protobuf/protoc/generator/message_test.exs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ defmodule Protobuf.Protoc.Generator.MessageTest do
2222
quote(
2323
do:
2424
t() :: %Pkg.Name.Foo{
25+
__protobuf__: true,
2526
__unknown_fields__: [Protobuf.unknown_field()]
2627
}
2728
)
@@ -47,6 +48,7 @@ defmodule Protobuf.Protoc.Generator.MessageTest do
4748
quote(
4849
do:
4950
t() :: %Foo{
51+
__protobuf__: true,
5052
__unknown_fields__: [Protobuf.unknown_field()]
5153
}
5254
)
@@ -175,6 +177,7 @@ defmodule Protobuf.Protoc.Generator.MessageTest do
175177
quote(
176178
do:
177179
t() :: %Foo{
180+
__protobuf__: true,
178181
__unknown_fields__: [Protobuf.unknown_field()],
179182
a: integer(),
180183
b: String.t(),
@@ -366,6 +369,7 @@ defmodule Protobuf.Protoc.Generator.MessageTest do
366369
quote(
367370
do:
368371
t() :: %Foo{
372+
__protobuf__: true,
369373
__unknown_fields__: [Protobuf.unknown_field()],
370374
bar: Bar.t() | nil,
371375
baz: [Baz.t()]
@@ -435,6 +439,7 @@ defmodule Protobuf.Protoc.Generator.MessageTest do
435439
quote(
436440
do:
437441
t() :: %FooBar.AbCd.Foo{
442+
__protobuf__: true,
438443
__unknown_fields__: [Protobuf.unknown_field()],
439444
a: %{optional(integer()) => FooBar.AbCd.Bar.t() | nil}
440445
}
@@ -524,6 +529,7 @@ defmodule Protobuf.Protoc.Generator.MessageTest do
524529
quote(
525530
do:
526531
t() :: %FooBar.AbCd.Foo{
532+
__protobuf__: true,
527533
__unknown_fields__: [Protobuf.unknown_field()],
528534
a: OtherPkg.MsgFoo.t() | nil
529535
}
@@ -584,6 +590,7 @@ defmodule Protobuf.Protoc.Generator.MessageTest do
584590
quote(
585591
do:
586592
t() :: %MyPkg.Foo{
593+
__protobuf__: true,
587594
__unknown_fields__: [Protobuf.unknown_field()],
588595
a: MyPkg.Foo.Nested.t() | nil
589596
}
@@ -689,6 +696,7 @@ defmodule Protobuf.Protoc.Generator.MessageTest do
689696
quote(
690697
do:
691698
t() :: %Foo{
699+
__protobuf__: true,
692700
__unknown_fields__: [Protobuf.unknown_field()],
693701
first: {:a, integer()} | {:b, integer()} | nil,
694702
other: integer() | nil,
@@ -772,6 +780,7 @@ defmodule Protobuf.Protoc.Generator.MessageTest do
772780
quote(
773781
do:
774782
t() :: %FooBar.AbCd.Foo{
783+
__protobuf__: true,
775784
__unknown_fields__: [Protobuf.unknown_field()],
776785
a: [FooBar.AbCd.EnumFoo.t()]
777786
}

0 commit comments

Comments
 (0)