diff --git a/lib/ex_limiter/plug.ex b/lib/ex_limiter/plug.ex index bded1b5..e6967c7 100644 --- a/lib/ex_limiter/plug.ex +++ b/lib/ex_limiter/plug.ex @@ -35,37 +35,11 @@ defmodule ExLimiter.Plug do """ import Plug.Conn + alias ExLimiter.Bucket + @compile_opts Application.compile_env(:ex_limiter, __MODULE__, []) - @limiter @compile_opts[:limiter] || ExLimiter - - defmodule Config do - @moduledoc false - @compile_opts Application.compile_env(:ex_limiter, ExLimiter.Plug, []) - @limit @compile_opts[:limit] || 10 - @scale @compile_opts[:scale] || 1000 - @fallback @compile_opts[:fallback] || ExLimiter.Plug - - defstruct scale: @scale, - limit: @limit, - bucket: &ExLimiter.Plug.get_bucket/1, - consumes: nil, - decorate: nil, - fallback: @fallback - - def new(opts) do - contents = - opts - |> Map.new() - |> Map.put_new(:consumes, fn _ -> 1 end) - |> Map.put_new(:decorate, &ExLimiter.Plug.decorate/2) - - struct(__MODULE__, contents) - end - end - def get_bucket(%{private: %{phoenix_controller: contr, phoenix_action: ac}} = conn) do - "#{contr}.#{ac}.#{ip(conn)}" - end + def get_bucket(%{private: %{phoenix_controller: contr, phoenix_action: ac}} = conn), do: "#{contr}.#{ac}.#{ip(conn)}" def render_error(conn, :rate_limited) do conn @@ -76,28 +50,46 @@ defmodule ExLimiter.Plug do @spec decorate(Plug.Conn.t(), {:ok, Bucket.t()} | {:rate_limited, bucket_name :: binary}) :: Plug.Conn.t() def decorate(conn, _), do: conn - def init(opts), do: Config.new(opts) + def consume(_conn), do: 1 + + def init(opts \\ []) do + @compile_opts + |> Keyword.merge(opts) + |> Keyword.validate!( + limiter: ExLimiter, + limit: 10, + scale: 1000, + fallback: __MODULE__, + bucket: &__MODULE__.get_bucket/1, + consumes: &__MODULE__.consume/1, + decorate: &__MODULE__.decorate/2 + ) + |> Map.new() + end + + def call(conn, config) do + %{ + limiter: limiter, + bucket: bucket_fun, + scale: scale, + limit: limit, + consumes: consume_fun, + decorate: decorate_fun, + fallback: fallback + } = config - def call(conn, %Config{ - bucket: bucket_fun, - scale: scale, - limit: limit, - consumes: consume_fun, - decorate: decorate_fun, - fallback: fallback - }) do bucket_name = bucket_fun.(conn) - case @limiter.consume(bucket_name, consume_fun.(conn), scale: scale, limit: limit) do + case limiter.consume(bucket_name, consume_fun.(conn), scale: scale, limit: limit) do {:ok, bucket} = response -> - remaining = @limiter.remaining(bucket, scale: scale, limit: limit) + remaining = limiter.remaining(bucket, scale: scale, limit: limit) conn |> put_rate_limit_headers(limit, scale, remaining) |> decorate_fun.(response) {:error, :rate_limited} -> - remaining = @limiter.remaining(%ExLimiter.Bucket{key: bucket_name}, scale: scale, limit: limit) + remaining = limiter.remaining(%Bucket{key: bucket_name}, scale: scale, limit: limit) conn |> put_rate_limit_headers(limit, scale, remaining) diff --git a/test/ex_limiter/plug_test.exs b/test/ex_limiter/plug_test.exs index bd30459..22b9195 100644 --- a/test/ex_limiter/plug_test.exs +++ b/test/ex_limiter/plug_test.exs @@ -83,7 +83,7 @@ defmodule ExLimiter.PlugTest do end defp setup_limiter(_) do - [limiter: ExLimiter.Plug.Config.new(consumes: &consumes/1)] + [limiter: ExLimiter.Plug.init(consumes: &consumes/1)] end defp consumes(%{params: %{"count" => count}}), do: count