Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package com.back.koreaTravelGuide.common.config

import org.springframework.context.annotation.Bean
import org.springframework.context.annotation.Configuration
import org.springframework.data.redis.connection.RedisConnectionFactory
import org.springframework.data.redis.core.RedisTemplate
import org.springframework.data.redis.serializer.StringRedisSerializer

@Configuration
class RedisConfig {
@Bean
fun redisTemplate(connectionFactory: RedisConnectionFactory): RedisTemplate<String, String> {
val template = RedisTemplate<String, String>()

template.connectionFactory = connectionFactory

// Key와 Value의 Serializer를 String으로 설정

template.keySerializer = StringRedisSerializer()

template.valueSerializer = StringRedisSerializer()

return template
}
}
Original file line number Diff line number Diff line change
@@ -1,38 +1,67 @@
package com.back.koreaTravelGuide.security
package com.back.koreaTravelGuide.common.security

import com.back.koreaTravelGuide.domain.user.enums.UserRole
import com.back.koreaTravelGuide.domain.user.repository.UserRepository
import jakarta.servlet.http.Cookie
import jakarta.servlet.http.HttpServletRequest
import jakarta.servlet.http.HttpServletResponse
import org.springframework.beans.factory.annotation.Value
import org.springframework.data.redis.core.RedisTemplate
import org.springframework.security.core.Authentication
import org.springframework.security.oauth2.core.user.OAuth2User
import org.springframework.security.web.authentication.SimpleUrlAuthenticationSuccessHandler
import org.springframework.stereotype.Component
import org.springframework.transaction.annotation.Transactional
import java.util.concurrent.TimeUnit

@Component
class CustomOAuth2LoginSuccessHandler(
private val jwtTokenProvider: JwtTokenProvider,
private val userRepository: UserRepository,
private val redisTemplate: RedisTemplate<String, String>,
@Value("\${jwt.refresh-token-expiration-days}") private val refreshTokenExpirationDays: Long,
) : SimpleUrlAuthenticationSuccessHandler() {
@Transactional
override fun onAuthenticationSuccess(
request: HttpServletRequest,
response: HttpServletResponse,
authentication: Authentication,
) {
val oAuth2User = authentication.principal as OAuth2User
val email = oAuth2User.attributes["email"] as String
val customUser = authentication.principal as CustomOAuth2User

val email = customUser.email

val user = userRepository.findByEmail(email)!!

if (user.role == UserRole.PENDING) {
val registerToken = jwtTokenProvider.createRegisterToken(user.id!!)

val targetUrl = "http://localhost:3000/signup/role?token=$registerToken"

redirectStrategy.sendRedirect(request, response, targetUrl)
} else {
val accessToken = jwtTokenProvider.createAccessToken(user.id!!, user.role)

val refreshToken = jwtTokenProvider.createRefreshToken(user.id!!)

val redisKey = "refreshToken:${user.id}"

redisTemplate.opsForValue().set(redisKey, refreshToken, refreshTokenExpirationDays, TimeUnit.DAYS)

val cookie =
Cookie("refreshToken", refreshToken).apply {
isHttpOnly = true

secure = true

path = "/"

maxAge = (refreshTokenExpirationDays * 24 * 60 * 60).toInt()
}

response.addCookie(cookie)

val targetUrl = "http://localhost:3000/oauth/callback?accessToken=$accessToken"

redirectStrategy.sendRedirect(request, response, targetUrl)
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.back.koreaTravelGuide.security
package com.back.koreaTravelGuide.common.security

import org.springframework.security.core.GrantedAuthority
import org.springframework.security.oauth2.core.user.DefaultOAuth2User
Expand All @@ -8,4 +8,15 @@ class CustomOAuth2User(
val email: String,
authorities: Collection<GrantedAuthority>,
attributes: Map<String, Any>,
) : DefaultOAuth2User(authorities, attributes, "email")
val nameAttributeKey: String,
) : DefaultOAuth2User(authorities, attributes, nameAttributeKey) {
override fun getName(): String {
val nameAttribute = getAttribute<Any>(nameAttributeKey)

if (nameAttribute is Map<*, *>) {
return nameAttribute["id"] as String
}

return nameAttribute.toString()
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.back.koreaTravelGuide.security
package com.back.koreaTravelGuide.common.security

import com.back.koreaTravelGuide.domain.user.entity.User
import com.back.koreaTravelGuide.domain.user.enums.UserRole
Expand Down Expand Up @@ -43,11 +43,14 @@ class CustomOAuth2UserService(

val authorities = listOf(SimpleGrantedAuthority("ROLE_${user.role.name}"))

val userNameAttributeName = userRequest.clientRegistration.providerDetails.userInfoEndpoint.userNameAttributeName

return CustomOAuth2User(
id = user.id!!,
email = user.email,
authorities = authorities,
attributes = attributes,
nameAttributeKey = userNameAttributeName,
)
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
package com.back.koreaTravelGuide.security
package com.back.koreaTravelGuide.common.security

import jakarta.servlet.FilterChain
import jakarta.servlet.http.HttpServletRequest
import jakarta.servlet.http.HttpServletResponse
import org.springframework.data.redis.core.RedisTemplate
import org.springframework.security.core.context.SecurityContextHolder
import org.springframework.stereotype.Component
import org.springframework.web.filter.OncePerRequestFilter

@Component
class JwtAuthenticationFilter(
private val jwtTokenProvider: JwtTokenProvider,
private val redisTemplate: RedisTemplate<String, String>,
) : OncePerRequestFilter() {
override fun doFilterInternal(
request: HttpServletRequest,
Expand All @@ -18,8 +20,11 @@ class JwtAuthenticationFilter(
) {
val token = resolveToken(request)

if (token != null && jwtTokenProvider.validateToken(token)) {
val isBlacklisted = if (token != null) redisTemplate.opsForValue().get(token) != null else false

if (token != null && !isBlacklisted && jwtTokenProvider.validateToken(token)) {
val authentication = jwtTokenProvider.getAuthentication(token)

SecurityContextHolder.getContext().authentication = authentication
}

Expand All @@ -28,6 +33,7 @@ class JwtAuthenticationFilter(

private fun resolveToken(request: HttpServletRequest): String? {
val bearerToken = request.getHeader("Authorization")

return if (bearerToken != null && bearerToken.startsWith("Bearer ")) {
bearerToken.substring(7)
} else {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package com.back.koreaTravelGuide.security
package com.back.koreaTravelGuide.common.security

import com.back.koreaTravelGuide.domain.user.enums.UserRole
import io.jsonwebtoken.Claims
import io.jsonwebtoken.Jwts
import io.jsonwebtoken.security.Keys
import org.slf4j.LoggerFactory
import org.springframework.beans.factory.annotation.Value
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken
import org.springframework.security.core.Authentication
Expand All @@ -17,8 +18,12 @@ import javax.crypto.SecretKey
class JwtTokenProvider(
@Value("\${jwt.secret-key}") private val secretKey: String,
@Value("\${jwt.access-token-expiration-minutes}") private val accessTokenExpirationMinutes: Long,
@Value("\${jwt.refresh-token-expiration-days}") private val refreshTokenExpirationDays: Long,
) {
private val logger = LoggerFactory.getLogger(JwtTokenProvider::class.java)

private val key: SecretKey by lazy {

Keys.hmacShaKeyFor(Base64.getEncoder().encode(secretKey.toByteArray()))
}

Expand All @@ -27,6 +32,7 @@ class JwtTokenProvider(
role: UserRole,
): String {
val now = Date()

val expiryDate = Date(now.time + accessTokenExpirationMinutes * 60 * 1000)

return Jwts.builder()
Expand All @@ -38,8 +44,22 @@ class JwtTokenProvider(
.compact()
}

fun createRefreshToken(userId: Long): String {
val now = Date()

val expiryDate = Date(now.time + refreshTokenExpirationDays * 24 * 60 * 60 * 1000)

return Jwts.builder()
.subject(userId.toString())
.issuedAt(now)
.expiration(expiryDate)
.signWith(key)
.compact()
}

fun createRegisterToken(userId: Long): String {
val now = Date()

val expiryDate = Date(now.time + 5 * 60 * 1000)

return Jwts.builder()
Expand All @@ -53,21 +73,37 @@ class JwtTokenProvider(
fun validateToken(token: String): Boolean {
try {
getClaimsFromToken(token)

return true
} catch (e: Exception) {
logger.error("Token validation error: ${e.message}")

return false
}
}

fun getAuthentication(token: String): Authentication {
val claims = getClaimsFromToken(token)

val userId = claims.subject.toLong()

val role = claims["role"] as? String ?: "ROLE_PENDING"

val authorities = listOf(SimpleGrantedAuthority(role))

return UsernamePasswordAuthenticationToken(userId, null, authorities)
}

fun getUserIdFromToken(token: String): Long {
return getClaimsFromToken(token).subject.toLong()
}

fun getRemainingTime(token: String): Long {
val expiration = getClaimsFromToken(token).expiration

return expiration.time - Date().time
}

private fun getClaimsFromToken(token: String): Claims {
return Jwts.parser()
.verifyWith(key)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package com.back.koreaTravelGuide.domain.auth.controller

import com.back.koreaTravelGuide.common.ApiResponse
import com.back.koreaTravelGuide.domain.auth.dto.request.UserRoleUpdateRequest
import com.back.koreaTravelGuide.domain.auth.dto.response.AccessTokenResponse
import com.back.koreaTravelGuide.domain.auth.dto.response.LoginResponse
import com.back.koreaTravelGuide.domain.auth.service.AuthService
import io.swagger.v3.oas.annotations.Operation
import jakarta.servlet.http.HttpServletRequest
import jakarta.servlet.http.HttpServletResponse
import org.springframework.beans.factory.annotation.Value
import org.springframework.http.ResponseEntity
import org.springframework.security.core.annotation.AuthenticationPrincipal
import org.springframework.web.bind.annotation.CookieValue
import org.springframework.web.bind.annotation.PostMapping
import org.springframework.web.bind.annotation.RequestBody
import org.springframework.web.bind.annotation.RequestMapping
import org.springframework.web.bind.annotation.RestController

@RestController
@RequestMapping("/api/auth")
class AuthController(
private val authService: AuthService,
@Value("\${jwt.refresh-token-expiration-days}") private val refreshTokenExpirationDays: Long,
) {
@PostMapping("/refresh")
fun refreshAccessToken(
@CookieValue("refreshToken") refreshToken: String,
response: HttpServletResponse,
): ResponseEntity<ApiResponse<AccessTokenResponse>> {
val (newAccessToken, newRefreshToken) = authService.refreshAccessToken(refreshToken)

val cookie =
jakarta.servlet.http.Cookie("refreshToken", newRefreshToken).apply {
isHttpOnly = true
secure = true
path = "/"
maxAge = (refreshTokenExpirationDays * 24 * 60 * 60).toInt()
}
response.addCookie(cookie)

return ResponseEntity.ok(ApiResponse("Access Token이 성공적으로 재발급되었습니다.", AccessTokenResponse(newAccessToken)))
}

@Operation(summary = "신규 사용자 역할 선택")
@PostMapping("/role")
fun updateUserRole(
@AuthenticationPrincipal userId: Long,
@RequestBody request: UserRoleUpdateRequest,
): ResponseEntity<ApiResponse<LoginResponse>> {
val loginResponse = authService.updateRoleAndLogin(userId, request.role)
return ResponseEntity.ok(ApiResponse("역할이 선택되었으며 로그인에 성공했습니다.", loginResponse))
}

@Operation(summary = "로그아웃")
@PostMapping("/logout")
fun logout(request: HttpServletRequest): ResponseEntity<ApiResponse<Unit>> {
val token =
request.getHeader("Authorization")?.substring(7)
?: throw IllegalArgumentException("토큰이 없습니다.")

authService.logout(token)

return ResponseEntity.ok(ApiResponse("로그아웃 되었습니다."))
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package com.back.koreaTravelGuide.domain.auth.dto.request

import com.back.koreaTravelGuide.domain.user.enums.UserRole

data class UserRoleUpdateRequest(
val role: UserRole,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package com.back.koreaTravelGuide.domain.auth.dto.response

data class AccessTokenResponse(
val accessToken: String,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package com.back.koreaTravelGuide.domain.auth.dto.response

data class LoginResponse(
val accessToken: String,
)
Loading