Skip to content

Commit 8cc4fe8

Browse files
committed
remove duplicate clientcredential parsing #35
1 parent c04307d commit 8cc4fe8

File tree

8 files changed

+47
-78
lines changed

8 files changed

+47
-78
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ type handler to not require client credentials:
135135

136136
```scala
137137
class MyTokenEndpoint extends TokenEndpoint {
138-
val passwordNoCred = new Password(ClientCredentialFetcher) {
138+
val passwordNoCred = new Password() {
139139
override def clientCredentialRequired = false
140140
}
141141

scala-oauth2-core/src/main/scala/scalaoauth2/provider/GrantHandler.scala

Lines changed: 22 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ trait GrantHandler {
1313
*/
1414
def clientCredentialRequired = true
1515

16-
def handleRequest[U](request: AuthorizationRequest, dataHandler: DataHandler[U]): Future[GrantHandlerResult]
16+
def handleRequest[U](request: AuthorizationRequest, optionalClientCredential: Option[ClientCredential], dataHandler: DataHandler[U]): Future[GrantHandlerResult]
1717

1818
/**
1919
* Returns valid access token.
@@ -43,10 +43,10 @@ trait GrantHandler {
4343
}
4444
}
4545

46-
class RefreshToken(clientCredentialFetcher: ClientCredentialFetcher) extends GrantHandler {
46+
class RefreshToken extends GrantHandler {
4747

48-
override def handleRequest[U](request: AuthorizationRequest, dataHandler: DataHandler[U]): Future[GrantHandlerResult] = {
49-
val clientCredential = clientCredentialFetcher.fetch(request).getOrElse(throw new InvalidRequest("Authorization header is invalid"))
48+
override def handleRequest[U](request: AuthorizationRequest, optionalClientCredential: Option[ClientCredential], dataHandler: DataHandler[U]): Future[GrantHandlerResult] = {
49+
val clientCredential = optionalClientCredential.getOrElse(throw new InvalidRequest("Client credential is required"))
5050
val refreshToken = request.requireRefreshToken
5151

5252
dataHandler.findAuthInfoByRefreshToken(refreshToken).flatMap { authInfoOption =>
@@ -68,54 +68,53 @@ class RefreshToken(clientCredentialFetcher: ClientCredentialFetcher) extends Gra
6868
}
6969
}
7070

71-
class Password(clientCredentialFetcher: ClientCredentialFetcher) extends GrantHandler {
71+
class Password extends GrantHandler {
72+
73+
override def handleRequest[U](request: AuthorizationRequest, optionalClientCredential: Option[ClientCredential], dataHandler: DataHandler[U]): Future[GrantHandlerResult] = {
74+
if (clientCredentialRequired && optionalClientCredential.isEmpty) {
75+
throw new InvalidRequest("Client credential is required")
76+
}
7277

73-
override def handleRequest[U](request: AuthorizationRequest, dataHandler: DataHandler[U]): Future[GrantHandlerResult] = {
74-
val clientCredential = clientCredentialFetcher.fetch(request)
75-
if (clientCredentialRequired && clientCredential.isEmpty)
76-
throw new InvalidRequest("Authorization header is invalid")
7778
val username = request.requireUsername
7879
val password = request.requirePassword
7980

8081
dataHandler.findUser(username, password).flatMap { userOption =>
8182
val user = userOption.getOrElse(throw new InvalidGrant("username or password is incorrect"))
8283
val scope = request.scope
83-
val clientId = clientCredential.map { _.clientId }
84+
val clientId = optionalClientCredential.map { _.clientId }
8485
val authInfo = AuthInfo(user, clientId, scope, None)
8586

8687
issueAccessToken(dataHandler, authInfo)
8788
}
8889
}
8990
}
9091

91-
class ClientCredentials(clientCredentialFetcher: ClientCredentialFetcher) extends GrantHandler {
92+
class ClientCredentials extends GrantHandler {
9293

93-
override def handleRequest[U](request: AuthorizationRequest, dataHandler: DataHandler[U]): Future[GrantHandlerResult] = {
94-
val clientCredential = clientCredentialFetcher.fetch(request).getOrElse(throw new InvalidRequest("Authorization header is invalid"))
95-
val clientSecret = clientCredential.clientSecret
96-
val clientId = clientCredential.clientId
94+
override def handleRequest[U](request: AuthorizationRequest, optionalClientCredential: Option[ClientCredential], dataHandler: DataHandler[U]): Future[GrantHandlerResult] = {
95+
val clientCredential = optionalClientCredential.getOrElse(throw new InvalidRequest("Client credential is required"))
9796
val scope = request.scope
9897

99-
dataHandler.findClientUser(clientId, clientSecret, scope).flatMap { userOption =>
100-
val user = userOption.getOrElse(throw new InvalidGrant("client_id or client_secret or scope is incorrect"))
101-
val authInfo = AuthInfo(user, Some(clientId), scope, None)
98+
dataHandler.findClientUser(clientCredential.clientId, clientCredential.clientSecret, scope).flatMap { optionalUser =>
99+
val user = optionalUser.getOrElse(throw new InvalidGrant("client_id or client_secret or scope is incorrect"))
100+
val authInfo = AuthInfo(user, Some(clientCredential.clientId), scope, None)
102101

103102
issueAccessToken(dataHandler, authInfo)
104103
}
105104
}
106105

107106
}
108107

109-
class AuthorizationCode(clientCredentialFetcher: ClientCredentialFetcher) extends GrantHandler {
108+
class AuthorizationCode extends GrantHandler {
110109

111-
override def handleRequest[U](request: AuthorizationRequest, dataHandler: DataHandler[U]): Future[GrantHandlerResult] = {
112-
val clientCredential = clientCredentialFetcher.fetch(request).getOrElse(throw new InvalidRequest("Authorization header is invalid"))
110+
override def handleRequest[U](request: AuthorizationRequest, optionalClientCredential: Option[ClientCredential], dataHandler: DataHandler[U]): Future[GrantHandlerResult] = {
111+
val clientCredential = optionalClientCredential.getOrElse(throw new InvalidRequest("Client credential is required"))
113112
val clientId = clientCredential.clientId
114113
val code = request.requireCode
115114
val redirectUri = request.redirectUri
116115

117-
dataHandler.findAuthInfoByCode(code).flatMap { authInfoOption =>
118-
val authInfo = authInfoOption.getOrElse(throw new InvalidGrant("Authorized information is not found by the code"))
116+
dataHandler.findAuthInfoByCode(code).flatMap { optionalAuthInfo =>
117+
val authInfo = optionalAuthInfo.getOrElse(throw new InvalidGrant("Authorized information is not found by the code"))
119118
if (authInfo.clientId != Some(clientId)) {
120119
throw new InvalidClient
121120
}

scala-oauth2-core/src/main/scala/scalaoauth2/provider/TokenEndpoint.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@ trait TokenEndpoint {
77
val fetcher = ClientCredentialFetcher
88

99
val handlers = Map(
10-
"authorization_code" -> new AuthorizationCode(fetcher),
11-
"refresh_token" -> new RefreshToken(fetcher),
12-
"client_credentials" -> new ClientCredentials(fetcher),
13-
"password" -> new Password(fetcher)
10+
"authorization_code" -> new AuthorizationCode(),
11+
"refresh_token" -> new RefreshToken(),
12+
"client_credentials" -> new ClientCredentials(),
13+
"password" -> new Password()
1414
)
1515

1616
def handleRequest[U](request: AuthorizationRequest, dataHandler: DataHandler[U]): Future[Either[OAuthError, GrantHandlerResult]] = try {
@@ -22,7 +22,7 @@ trait TokenEndpoint {
2222
if (!validClient) {
2323
Future.successful(Left(throw new InvalidClient()))
2424
} else {
25-
handler.handleRequest(request, dataHandler).map(Right(_))
25+
handler.handleRequest(request, Some(clientCredential), dataHandler).map(Right(_))
2626
}
2727
}.recover {
2828
case e: OAuthError => Left(e)
@@ -31,7 +31,7 @@ trait TokenEndpoint {
3131
if (handler.clientCredentialRequired) {
3232
throw new InvalidRequest("Client credential is not found")
3333
} else {
34-
handler.handleRequest(request, dataHandler).map(Right(_)).recover {
34+
handler.handleRequest(request, None, dataHandler).map(Right(_)).recover {
3535
case e: OAuthError => Left(e)
3636
}
3737
}

scala-oauth2-core/src/test/scala/scalaoauth2/provider/AuthorizationCodeSpec.scala

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,17 @@
11
package scalaoauth2.provider
22

3-
import org.scalatest._
43
import org.scalatest.Matchers._
4+
import org.scalatest._
55
import org.scalatest.concurrent.ScalaFutures
66

7-
import scala.concurrent.Await
8-
import scala.concurrent.duration._
97
import scala.concurrent.Future
108

119
class AuthorizationCodeSpec extends FlatSpec with ScalaFutures {
1210

1311
it should "handle request" in {
14-
val authorizationCode = new AuthorizationCode(new MockClientCredentialFetcher())
12+
val authorizationCode = new AuthorizationCode()
1513
val request = AuthorizationRequest(Map(), Map("code" -> Seq("code1"), "redirect_uri" -> Seq("http://example.com/")))
16-
val f = authorizationCode.handleRequest(request, new MockDataHandler() {
14+
val f = authorizationCode.handleRequest(request, Some(ClientCredential("clientId1", "clientSecret1")), new MockDataHandler() {
1715

1816
override def findAuthInfoByCode(code: String): Future[Option[AuthInfo[User]]] = Future.successful(Some(
1917
AuthInfo(user = MockUser(10000, "username"), clientId = Some("clientId1"), scope = Some("all"), redirectUri = Some("http://example.com/"))
@@ -32,9 +30,9 @@ class AuthorizationCodeSpec extends FlatSpec with ScalaFutures {
3230
}
3331

3432
it should "handle request if redirectUrl is none" in {
35-
val authorizationCode = new AuthorizationCode(new MockClientCredentialFetcher())
33+
val authorizationCode = new AuthorizationCode()
3634
val request = AuthorizationRequest(Map(), Map("code" -> Seq("code1"), "redirect_uri" -> Seq("http://example.com/")))
37-
val f = authorizationCode.handleRequest(request, new MockDataHandler() {
35+
val f = authorizationCode.handleRequest(request, Some(ClientCredential("clientId1", "clientSecret1")), new MockDataHandler() {
3836

3937
override def findAuthInfoByCode(code: String): Future[Option[AuthInfo[MockUser]]] = Future.successful(Some(
4038
AuthInfo(user = MockUser(10000, "username"), clientId = Some("clientId1"), scope = Some("all"), redirectUri = None)
@@ -51,10 +49,4 @@ class AuthorizationCodeSpec extends FlatSpec with ScalaFutures {
5149
result.scope should be (Some("all"))
5250
}
5351
}
54-
55-
class MockClientCredentialFetcher extends ClientCredentialFetcher {
56-
57-
override def fetch(request: AuthorizationRequest): Option[ClientCredential] = Some(ClientCredential("clientId1", "clientSecret1"))
58-
59-
}
6052
}
Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,17 @@
11
package scalaoauth2.provider
22

3-
import org.scalatest._
43
import org.scalatest.Matchers._
4+
import org.scalatest._
55
import org.scalatest.concurrent.ScalaFutures
66

7-
import scala.concurrent.Await
8-
import scala.concurrent.duration._
97
import scala.concurrent.Future
108

119
class ClientCredentialsSpec extends FlatSpec with ScalaFutures {
1210

1311
it should "handle request" in {
14-
val clientCredentials = new ClientCredentials(new MockClientCredentialFetcher())
12+
val clientCredentials = new ClientCredentials()
1513
val request = AuthorizationRequest(Map(), Map("scope" -> Seq("all")))
16-
val f = clientCredentials.handleRequest(request, new MockDataHandler() {
14+
val f = clientCredentials.handleRequest(request, Some(ClientCredential("clientId1", "clientSecret1")), new MockDataHandler() {
1715

1816
override def findClientUser(clientId: String, clientSecret: String, scope: Option[String]): Future[Option[User]] = Future.successful(Some(MockUser(10000, "username")))
1917

@@ -28,10 +26,4 @@ class ClientCredentialsSpec extends FlatSpec with ScalaFutures {
2826
result.scope should be (Some("all"))
2927
}
3028
}
31-
32-
class MockClientCredentialFetcher extends ClientCredentialFetcher {
33-
34-
override def fetch(request: AuthorizationRequest): Option[ClientCredential] = Some(ClientCredential("clientId1", "clientSecret1"))
35-
36-
}
3729
}

scala-oauth2-core/src/test/scala/scalaoauth2/provider/PasswordSpec.scala

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,17 @@ import scala.concurrent.Future
88

99
class PasswordSpec extends FlatSpec with ScalaFutures {
1010

11-
val passwordClientCredReq = new Password(new MockClientCredentialFetcher())
12-
val passwordNoClientCredReq = new Password(new MockClientCredentialFetcher()) {
11+
val passwordClientCredReq = new Password()
12+
val passwordNoClientCredReq = new Password() {
1313
override def clientCredentialRequired = false
1414
}
1515

16-
"Password when client credential required" should "handle request" in handlesRequest(passwordClientCredReq)
17-
"Password when client credential not required" should "handle request" in handlesRequest(passwordNoClientCredReq)
16+
"Password when client credential required" should "handle request" in handlesRequest(passwordClientCredReq, Some(ClientCredential("clientId1", "clientSecret1")))
17+
"Password when client credential not required" should "handle request" in handlesRequest(passwordNoClientCredReq, None)
1818

19-
def handlesRequest(password: Password) = {
19+
def handlesRequest(password: Password, clientCredential: Option[ClientCredential]) = {
2020
val request = AuthorizationRequest(Map(), Map("username" -> Seq("user"), "password" -> Seq("pass"), "scope" -> Seq("all")))
21-
val f = password.handleRequest(request, new MockDataHandler() {
21+
val f = password.handleRequest(request, clientCredential, new MockDataHandler() {
2222

2323
override def findUser(username: String, password: String): Future[Option[User]] = Future.successful(Some(MockUser(10000, "username")))
2424

@@ -34,10 +34,4 @@ class PasswordSpec extends FlatSpec with ScalaFutures {
3434
result.scope should be(Some("all"))
3535
}
3636
}
37-
38-
class MockClientCredentialFetcher extends ClientCredentialFetcher {
39-
40-
override def fetch(request: AuthorizationRequest): Option[ClientCredential] = Some(ClientCredential("clientId1", "clientSecret1"))
41-
42-
}
4337
}

scala-oauth2-core/src/test/scala/scalaoauth2/provider/RefreshTokenSpec.scala

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,14 @@ import org.scalatest.FlatSpec
44
import org.scalatest.Matchers._
55
import org.scalatest.concurrent.ScalaFutures
66

7-
import scala.concurrent.Await
8-
import scala.concurrent.duration._
97
import scala.concurrent.Future
108

119
class RefreshTokenSpec extends FlatSpec with ScalaFutures {
1210

1311
it should "handle request" in {
14-
val refreshToken = new RefreshToken(new MockClientCredentialFetcher())
12+
val refreshToken = new RefreshToken()
1513
val request = AuthorizationRequest(Map(), Map("refresh_token" -> Seq("refreshToken1")))
16-
val f = refreshToken.handleRequest(request, new MockDataHandler() {
14+
val f = refreshToken.handleRequest(request, Some(ClientCredential("clientId1", "clientSecret1")), new MockDataHandler() {
1715

1816
override def findAuthInfoByRefreshToken(refreshToken: String): Future[Option[AuthInfo[User]]] =
1917
Future.successful(Some(AuthInfo(user = MockUser(10000, "username"), clientId = Some("clientId1"), scope = None, redirectUri = None)))
@@ -30,10 +28,4 @@ class RefreshTokenSpec extends FlatSpec with ScalaFutures {
3028
result.scope should be (None)
3129
}
3230
}
33-
34-
class MockClientCredentialFetcher extends ClientCredentialFetcher {
35-
36-
override def fetch(request: AuthorizationRequest): Option[ClientCredential] = Some(ClientCredential("clientId1", "clientSecret1"))
37-
38-
}
3931
}

scala-oauth2-core/src/test/scala/scalaoauth2/provider/TokenEndPointSpec.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ class TokenEndPointSpec extends FlatSpec with ScalaFutures {
100100
)
101101

102102
val dataHandler = successfulDataHandler()
103-
val passwordNoCred = new Password(ClientCredentialFetcher) {
103+
val passwordNoCred = new Password() {
104104
override def clientCredentialRequired = false
105105
}
106106
class MyTokenEndpoint extends TokenEndpoint {
@@ -165,7 +165,7 @@ class TokenEndPointSpec extends FlatSpec with ScalaFutures {
165165

166166
object TestTokenEndpoint extends TokenEndpoint {
167167
override val handlers = Map(
168-
"password" -> new Password(fetcher)
168+
"password" -> new Password()
169169
)
170170
}
171171

0 commit comments

Comments
 (0)