@@ -35,6 +35,9 @@ import software.aws.toolkits.core.utils.tryOrNull
3535import software.aws.toolkits.core.utils.warn
3636import software.aws.toolkits.telemetry.AuthTelemetry
3737import software.aws.toolkits.telemetry.Result
38+ import java.io.ByteArrayInputStream
39+ import java.io.ByteArrayOutputStream
40+ import java.io.IOException
3841import java.io.InputStream
3942import java.io.OutputStream
4043import java.nio.file.Path
@@ -46,6 +49,7 @@ import java.time.Instant
4649import java.time.ZoneOffset
4750import java.time.format.DateTimeFormatter.ISO_INSTANT
4851import java.util.TimeZone
52+ import java.util.concurrent.ConcurrentHashMap
4953
5054/* *
5155 * Caches the [AccessToken] to disk to allow it to be re-used with other tools such as the CLI.
@@ -98,24 +102,30 @@ class DiskCache(
98102
99103 override fun invalidateClientRegistration (ssoRegion : String ) {
100104 LOG .debug { " invalidateClientRegistration for $ssoRegion " }
105+ InMemoryCache .remove(clientRegistrationCache(ssoRegion).toString())
101106 clientRegistrationCache(ssoRegion).tryDeleteIfExists()
102107 }
103108
104109 override fun loadClientRegistration (cacheKey : ClientRegistrationCacheKey ): ClientRegistration ? {
105110 LOG .debug { " loadClientRegistration for $cacheKey " }
106111 val inputStream = clientRegistrationCache(cacheKey).tryInputStreamIfExists()
107- if (inputStream == null ) {
108- val stage = LoadCredentialStage .ACCESS_FILE
109- LOG .warn { " Failed to load Client Registration: cache file does not exist" }
110- AuthTelemetry .modifyConnection(
111- action = " Load cache file" ,
112- source = " loadClientRegistration" ,
113- result = Result .Failed ,
114- reason = " Failed to load Client Registration" ,
115- reasonDesc = " Load Step:$stage failed. Unable to load file"
116- )
117- return null
118- }
112+ ? : // try to load from in memory cache
113+ return InMemoryCache .get(clientRegistrationCache(cacheKey).toString())?.let { data ->
114+ ByteArrayInputStream (data).use { memoryStream ->
115+ loadClientRegistration(memoryStream)
116+ }
117+ } ? : run {
118+ val stage = LoadCredentialStage .ACCESS_FILE
119+ LOG .warn { " Failed to load Client Registration: cache file does not exist" }
120+ AuthTelemetry .modifyConnection(
121+ action = " Load cache file" ,
122+ source = " loadClientRegistration" ,
123+ result = Result .Failed ,
124+ reason = " Failed to load Client Registration" ,
125+ reasonDesc = " Load Step:$stage failed. Unable to load file"
126+ )
127+ null
128+ }
119129 return loadClientRegistration(inputStream)
120130 }
121131
@@ -130,6 +140,7 @@ class DiskCache(
130140 override fun invalidateClientRegistration (cacheKey : ClientRegistrationCacheKey ) {
131141 LOG .debug { " invalidateClientRegistration for $cacheKey " }
132142 try {
143+ InMemoryCache .remove(clientRegistrationCache(cacheKey).toString())
133144 clientRegistrationCache(cacheKey).tryDeleteIfExists()
134145 } catch (e: Exception ) {
135146 AuthTelemetry .modifyConnection(
@@ -146,6 +157,7 @@ class DiskCache(
146157 override fun invalidateAccessToken (ssoUrl : String ) {
147158 LOG .debug { " invalidateAccessToken for $ssoUrl " }
148159 try {
160+ InMemoryCache .remove(accessTokenCache(ssoUrl).toString())
149161 accessTokenCache(ssoUrl).tryDeleteIfExists()
150162 } catch (e: Exception ) {
151163 AuthTelemetry .modifyConnection(
@@ -162,11 +174,18 @@ class DiskCache(
162174 override fun loadAccessToken (cacheKey : AccessTokenCacheKey ): AccessToken ? {
163175 LOG .debug { " loadAccessToken for $cacheKey " }
164176 val cacheFile = accessTokenCache(cacheKey)
165- val inputStream = cacheFile.tryInputStreamIfExists() ? : return null
166-
167- val token = loadAccessToken(inputStream)
168-
169- return token
177+ // If file exists, returns InputStream, if not returns null
178+ return cacheFile.tryInputStreamIfExists()
179+ // try to load and parse access token, returns AccessToken or null if expired
180+ ?.let { loadAccessToken(it) }
181+ // If file doesn't exist or loadAccessToken failed, try in-memory cache
182+ ? : InMemoryCache .get(cacheFile.toString())?.let { data ->
183+ // If in-memory cache has data, create stream and try to load token
184+ ByteArrayInputStream (data).use { memoryStream ->
185+ loadAccessToken(memoryStream)
186+ }
187+ // If both file system and in-memory cache attempts fail, returns null
188+ }
170189 }
171190
172191 override fun saveAccessToken (cacheKey : AccessTokenCacheKey , accessToken : AccessToken ) {
@@ -180,6 +199,7 @@ class DiskCache(
180199 override fun invalidateAccessToken (cacheKey : AccessTokenCacheKey ) {
181200 LOG .debug { " invalidateAccessToken for $cacheKey " }
182201 try {
202+ InMemoryCache .remove(accessTokenCache(cacheKey).toString())
183203 accessTokenCache(cacheKey).tryDeleteIfExists()
184204 } catch (e: Exception ) {
185205 AuthTelemetry .modifyConnection(
@@ -278,6 +298,14 @@ class DiskCache(
278298 outputStream().use(consumer)
279299 }
280300 } catch (e: Exception ) {
301+ when {
302+ e is IOException -> {
303+ if (e.message?.contains(" No space left on device" ) == true ) {
304+ LOG .warn { " Disk space full. Storing credentials in memory for this session" }
305+ storeInMemory(path, consumer)
306+ }
307+ }
308+ }
281309 AuthTelemetry .modifyConnection(
282310 action = " Write file" ,
283311 source = " writeKey" ,
@@ -294,6 +322,27 @@ class DiskCache(
294322
295323 private fun AccessToken.isDefinitelyExpired (): Boolean = refreshToken == null && ! expiresAt.isNotExpired()
296324
325+ private fun storeInMemory (path : Path , consumer : (OutputStream ) -> Unit ) {
326+ val byteArrayOutputStream = ByteArrayOutputStream ()
327+ consumer(byteArrayOutputStream)
328+ val data = byteArrayOutputStream.toByteArray()
329+ InMemoryCache .put(path.toString(), data)
330+ }
331+
332+ private object InMemoryCache {
333+ private val cache = ConcurrentHashMap <String , ByteArray >()
334+
335+ fun put (key : String , value : ByteArray ) {
336+ cache[key] = value
337+ }
338+
339+ fun get (key : String ): ByteArray? = cache[key]
340+
341+ fun remove (key : String ) {
342+ cache.remove(key)
343+ }
344+ }
345+
297346 private class CliCompatibleInstantDeserializer : StdDeserializer <Instant >(Instant : :class.java) {
298347 override fun deserialize (parser : JsonParser , context : DeserializationContext ): Instant {
299348 val dateString = parser.valueAsString
0 commit comments