Skip to content

Commit 34120e7

Browse files
committed
initial commit
1 parent 06181dd commit 34120e7

File tree

1 file changed

+66
-17
lines changed
  • plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/credentials/sso

1 file changed

+66
-17
lines changed

plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/credentials/sso/DiskCache.kt

Lines changed: 66 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ import software.aws.toolkits.core.utils.tryOrNull
3535
import software.aws.toolkits.core.utils.warn
3636
import software.aws.toolkits.telemetry.AuthTelemetry
3737
import software.aws.toolkits.telemetry.Result
38+
import java.io.ByteArrayInputStream
39+
import java.io.ByteArrayOutputStream
40+
import java.io.IOException
3841
import java.io.InputStream
3942
import java.io.OutputStream
4043
import java.nio.file.Path
@@ -46,6 +49,7 @@ import java.time.Instant
4649
import java.time.ZoneOffset
4750
import java.time.format.DateTimeFormatter.ISO_INSTANT
4851
import java.util.TimeZone
52+
import java.util.concurrent.ConcurrentHashMap
4953

5054
/**
5155
* Caches the [AccessToken] to disk to allow it to be re-used with other tools such as the CLI.
@@ -98,24 +102,30 @@ class DiskCache(
98102

99103
override fun invalidateClientRegistration(ssoRegion: String) {
100104
LOG.debug { "invalidateClientRegistration for $ssoRegion" }
105+
InMemoryCache.remove(clientRegistrationCache(ssoRegion).toString())
101106
clientRegistrationCache(ssoRegion).tryDeleteIfExists()
102107
}
103108

104109
override fun loadClientRegistration(cacheKey: ClientRegistrationCacheKey): ClientRegistration? {
105110
LOG.debug { "loadClientRegistration for $cacheKey" }
106111
val inputStream = clientRegistrationCache(cacheKey).tryInputStreamIfExists()
107-
if (inputStream == null) {
108-
val stage = LoadCredentialStage.ACCESS_FILE
109-
LOG.warn { "Failed to load Client Registration: cache file does not exist" }
110-
AuthTelemetry.modifyConnection(
111-
action = "Load cache file",
112-
source = "loadClientRegistration",
113-
result = Result.Failed,
114-
reason = "Failed to load Client Registration",
115-
reasonDesc = "Load Step:$stage failed. Unable to load file"
116-
)
117-
return null
118-
}
112+
?: //try to load from in memory cache
113+
return InMemoryCache.get(clientRegistrationCache(cacheKey).toString())?.let { data ->
114+
ByteArrayInputStream(data).use { memoryStream ->
115+
loadClientRegistration(memoryStream)
116+
}
117+
} ?: run {
118+
val stage = LoadCredentialStage.ACCESS_FILE
119+
LOG.warn { "Failed to load Client Registration: cache file does not exist" }
120+
AuthTelemetry.modifyConnection(
121+
action = "Load cache file",
122+
source = "loadClientRegistration",
123+
result = Result.Failed,
124+
reason = "Failed to load Client Registration",
125+
reasonDesc = "Load Step:$stage failed. Unable to load file"
126+
)
127+
null
128+
}
119129
return loadClientRegistration(inputStream)
120130
}
121131

@@ -130,6 +140,7 @@ class DiskCache(
130140
override fun invalidateClientRegistration(cacheKey: ClientRegistrationCacheKey) {
131141
LOG.debug { "invalidateClientRegistration for $cacheKey" }
132142
try {
143+
InMemoryCache.remove(clientRegistrationCache(cacheKey).toString())
133144
clientRegistrationCache(cacheKey).tryDeleteIfExists()
134145
} catch (e: Exception) {
135146
AuthTelemetry.modifyConnection(
@@ -146,6 +157,7 @@ class DiskCache(
146157
override fun invalidateAccessToken(ssoUrl: String) {
147158
LOG.debug { "invalidateAccessToken for $ssoUrl" }
148159
try {
160+
InMemoryCache.remove(accessTokenCache(ssoUrl).toString())
149161
accessTokenCache(ssoUrl).tryDeleteIfExists()
150162
} catch (e: Exception) {
151163
AuthTelemetry.modifyConnection(
@@ -162,11 +174,18 @@ class DiskCache(
162174
override fun loadAccessToken(cacheKey: AccessTokenCacheKey): AccessToken? {
163175
LOG.debug { "loadAccessToken for $cacheKey" }
164176
val cacheFile = accessTokenCache(cacheKey)
165-
val inputStream = cacheFile.tryInputStreamIfExists() ?: return null
166-
167-
val token = loadAccessToken(inputStream)
168-
169-
return token
177+
// If file exists, returns InputStream, if not returns null
178+
return cacheFile.tryInputStreamIfExists()
179+
//try to load and parse access token, returns AccessToken or null if expired
180+
?.let { loadAccessToken(it) }
181+
// If file doesn't exist or loadAccessToken failed, try in-memory cache
182+
?: InMemoryCache.get(cacheFile.toString())?.let { data ->
183+
// If in-memory cache has data, create stream and try to load token
184+
ByteArrayInputStream(data).use { memoryStream ->
185+
loadAccessToken(memoryStream)
186+
}
187+
// If both file system and in-memory cache attempts fail, returns null
188+
}
170189
}
171190

172191
override fun saveAccessToken(cacheKey: AccessTokenCacheKey, accessToken: AccessToken) {
@@ -180,6 +199,7 @@ class DiskCache(
180199
override fun invalidateAccessToken(cacheKey: AccessTokenCacheKey) {
181200
LOG.debug { "invalidateAccessToken for $cacheKey" }
182201
try {
202+
InMemoryCache.remove(accessTokenCache(cacheKey).toString())
183203
accessTokenCache(cacheKey).tryDeleteIfExists()
184204
} catch (e: Exception) {
185205
AuthTelemetry.modifyConnection(
@@ -278,6 +298,14 @@ class DiskCache(
278298
outputStream().use(consumer)
279299
}
280300
} catch (e: Exception) {
301+
when {
302+
e is IOException -> {
303+
if (e.message?.contains("No space left on device") == true) {
304+
LOG.warn { "Disk space full. Storing credentials in memory for this session" }
305+
storeInMemory(path, consumer)
306+
}
307+
}
308+
}
281309
AuthTelemetry.modifyConnection(
282310
action = "Write file",
283311
source = "writeKey",
@@ -294,6 +322,27 @@ class DiskCache(
294322

295323
private fun AccessToken.isDefinitelyExpired(): Boolean = refreshToken == null && !expiresAt.isNotExpired()
296324

325+
private fun storeInMemory(path: Path, consumer: (OutputStream) -> Unit) {
326+
val byteArrayOutputStream = ByteArrayOutputStream()
327+
consumer(byteArrayOutputStream)
328+
val data = byteArrayOutputStream.toByteArray()
329+
InMemoryCache.put(path.toString(), data)
330+
}
331+
332+
private object InMemoryCache {
333+
private val cache = ConcurrentHashMap<String, ByteArray>()
334+
335+
fun put(key: String, value: ByteArray) {
336+
cache[key] = value
337+
}
338+
339+
fun get(key: String): ByteArray? = cache[key]
340+
341+
fun remove(key: String) {
342+
cache.remove(key)
343+
}
344+
}
345+
297346
private class CliCompatibleInstantDeserializer : StdDeserializer<Instant>(Instant::class.java) {
298347
override fun deserialize(parser: JsonParser, context: DeserializationContext): Instant {
299348
val dateString = parser.valueAsString

0 commit comments

Comments
 (0)