2929#include " llvm/Object/Binary.h"
3030#include " llvm/Object/ObjectFile.h"
3131#include " llvm/Support/Casting.h"
32+ #include " llvm/Support/Compiler.h"
3233#include " llvm/Support/Compression.h"
3334#include " llvm/Support/Debug.h"
3435#include " llvm/Support/EndianStream.h"
@@ -1127,13 +1128,116 @@ CompressedOffloadBundle::compress(llvm::compression::Params P,
11271128 llvm::StringRef (FinalBuffer.data (), FinalBuffer.size ()));
11281129}
11291130
1131+ // Use packed structs to avoid padding, such that the structs map the serialized
1132+ // format.
1133+ LLVM_PACKED_START
1134+ union RawCompressedBundleHeader {
1135+ struct CommonFields {
1136+ uint32_t Magic;
1137+ uint16_t Version;
1138+ uint16_t Method;
1139+ };
1140+
1141+ struct V1Header {
1142+ CommonFields Common;
1143+ uint32_t UncompressedFileSize;
1144+ uint64_t Hash;
1145+ };
1146+
1147+ struct V2Header {
1148+ CommonFields Common;
1149+ uint32_t FileSize;
1150+ uint32_t UncompressedFileSize;
1151+ uint64_t Hash;
1152+ };
1153+
1154+ struct V3Header {
1155+ CommonFields Common;
1156+ uint64_t FileSize;
1157+ uint64_t UncompressedFileSize;
1158+ uint64_t Hash;
1159+ };
1160+
1161+ CommonFields Common;
1162+ V1Header V1;
1163+ V2Header V2;
1164+ V3Header V3;
1165+ };
1166+ LLVM_PACKED_END
1167+
1168+ // Helper method to get header size based on version
1169+ static size_t getHeaderSize (uint16_t Version) {
1170+ switch (Version) {
1171+ case 1 :
1172+ return sizeof (RawCompressedBundleHeader::V1Header);
1173+ case 2 :
1174+ return sizeof (RawCompressedBundleHeader::V2Header);
1175+ case 3 :
1176+ return sizeof (RawCompressedBundleHeader::V3Header);
1177+ default :
1178+ llvm_unreachable (" Unsupported version" );
1179+ }
1180+ }
1181+
1182+ Expected<CompressedOffloadBundle::CompressedBundleHeader>
1183+ CompressedOffloadBundle::CompressedBundleHeader::tryParse (StringRef Blob) {
1184+ assert (Blob.size () >= sizeof (RawCompressedBundleHeader::CommonFields));
1185+ assert (llvm::identify_magic (Blob) ==
1186+ llvm::file_magic::offload_bundle_compressed);
1187+
1188+ RawCompressedBundleHeader Header;
1189+ memcpy (&Header, Blob.data (), std::min (Blob.size (), sizeof (Header)));
1190+
1191+ CompressedBundleHeader Normalized;
1192+ Normalized.Version = Header.Common .Version ;
1193+
1194+ size_t RequiredSize = getHeaderSize (Normalized.Version );
1195+ if (Blob.size () < RequiredSize)
1196+ return createStringError (inconvertibleErrorCode (),
1197+ " Compressed bundle header size too small" );
1198+
1199+ switch (Normalized.Version ) {
1200+ case 1 :
1201+ Normalized.UncompressedFileSize = Header.V1 .UncompressedFileSize ;
1202+ Normalized.Hash = Header.V1 .Hash ;
1203+ break ;
1204+ case 2 :
1205+ Normalized.FileSize = Header.V2 .FileSize ;
1206+ Normalized.UncompressedFileSize = Header.V2 .UncompressedFileSize ;
1207+ Normalized.Hash = Header.V2 .Hash ;
1208+ break ;
1209+ case 3 :
1210+ Normalized.FileSize = Header.V3 .FileSize ;
1211+ Normalized.UncompressedFileSize = Header.V3 .UncompressedFileSize ;
1212+ Normalized.Hash = Header.V3 .Hash ;
1213+ break ;
1214+ default :
1215+ return createStringError (inconvertibleErrorCode (),
1216+ " Unknown compressed bundle version" );
1217+ }
1218+
1219+ // Determine compression format
1220+ switch (Header.Common .Method ) {
1221+ case static_cast <uint16_t >(compression::Format::Zlib):
1222+ case static_cast <uint16_t >(compression::Format::Zstd):
1223+ Normalized.CompressionFormat =
1224+ static_cast <compression::Format>(Header.Common .Method );
1225+ break ;
1226+ default :
1227+ return createStringError (inconvertibleErrorCode (),
1228+ " Unknown compressing method" );
1229+ }
1230+
1231+ return Normalized;
1232+ }
1233+
11301234llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>>
11311235CompressedOffloadBundle::decompress (const llvm::MemoryBuffer &Input,
11321236 bool Verbose) {
11331237 StringRef Blob = Input.getBuffer ();
11341238
11351239 // Check minimum header size (using V1 as it's the smallest)
1136- if (Blob.size () < V1HeaderSize )
1240+ if (Blob.size () < sizeof (RawCompressedBundleHeader::CommonFields) )
11371241 return llvm::MemoryBuffer::getMemBufferCopy (Blob);
11381242
11391243 if (llvm::identify_magic (Blob) !=
@@ -1143,76 +1247,28 @@ CompressedOffloadBundle::decompress(const llvm::MemoryBuffer &Input,
11431247 return llvm::MemoryBuffer::getMemBufferCopy (Blob);
11441248 }
11451249
1146- size_t CurrentOffset = MagicSize;
1147-
1148- // Read version
1149- uint16_t ThisVersion;
1150- memcpy (&ThisVersion, Blob.data () + CurrentOffset, sizeof (uint16_t ));
1151- CurrentOffset += VersionFieldSize;
1152-
1153- // Verify header size based on version
1154- if (ThisVersion >= 2 && ThisVersion <= 3 ) {
1155- size_t RequiredSize = (ThisVersion == 2 ) ? V2HeaderSize : V3HeaderSize;
1156- if (Blob.size () < RequiredSize)
1157- return createStringError (inconvertibleErrorCode (),
1158- " Compressed bundle header size too small" );
1159- }
1160-
1161- // Read compression method
1162- uint16_t CompressionMethod;
1163- memcpy (&CompressionMethod, Blob.data () + CurrentOffset, sizeof (uint16_t ));
1164- CurrentOffset += MethodFieldSize;
1165-
1166- // Read total file size (version 2+)
1167- uint64_t TotalFileSize = 0 ;
1168- if (ThisVersion >= 2 ) {
1169- if (ThisVersion == 2 ) {
1170- uint32_t TotalFileSize32;
1171- memcpy (&TotalFileSize32, Blob.data () + CurrentOffset, sizeof (uint32_t ));
1172- TotalFileSize = TotalFileSize32;
1173- CurrentOffset += FileSizeFieldSizeV2;
1174- } else { // Version 3
1175- memcpy (&TotalFileSize, Blob.data () + CurrentOffset, sizeof (uint64_t ));
1176- CurrentOffset += FileSizeFieldSizeV3;
1177- }
1178- }
1250+ Expected<CompressedBundleHeader> HeaderOrErr =
1251+ CompressedBundleHeader::tryParse (Blob);
1252+ if (!HeaderOrErr)
1253+ return HeaderOrErr.takeError ();
11791254
1180- // Read uncompressed size
1181- uint64_t UncompressedSize = 0 ;
1182- if (ThisVersion <= 2 ) {
1183- uint32_t UncompressedSize32;
1184- memcpy (&UncompressedSize32, Blob.data () + CurrentOffset, sizeof (uint32_t ));
1185- UncompressedSize = UncompressedSize32;
1186- CurrentOffset += UncompressedSizeFieldSizeV2;
1187- } else { // Version 3
1188- memcpy (&UncompressedSize, Blob.data () + CurrentOffset, sizeof (uint64_t ));
1189- CurrentOffset += UncompressedSizeFieldSizeV3;
1190- }
1255+ const CompressedBundleHeader &Normalized = *HeaderOrErr;
1256+ unsigned ThisVersion = Normalized.Version ;
1257+ size_t HeaderSize = getHeaderSize (ThisVersion);
11911258
1192- // Read hash
1193- uint64_t StoredHash;
1194- memcpy (&StoredHash, Blob.data () + CurrentOffset, sizeof (uint64_t ));
1195- CurrentOffset += HashFieldSize;
1259+ llvm::compression::Format CompressionFormat = Normalized.CompressionFormat ;
11961260
1197- // Determine compression format
1198- llvm::compression::Format CompressionFormat;
1199- if (CompressionMethod ==
1200- static_cast <uint16_t >(llvm::compression::Format::Zlib))
1201- CompressionFormat = llvm::compression::Format::Zlib;
1202- else if (CompressionMethod ==
1203- static_cast <uint16_t >(llvm::compression::Format::Zstd))
1204- CompressionFormat = llvm::compression::Format::Zstd;
1205- else
1206- return createStringError (inconvertibleErrorCode (),
1207- " Unknown compressing method" );
1261+ size_t TotalFileSize = Normalized.FileSize .value_or (0 );
1262+ size_t UncompressedSize = Normalized.UncompressedFileSize ;
1263+ auto StoredHash = Normalized.Hash ;
12081264
12091265 llvm::Timer DecompressTimer (" Decompression Timer" , " Decompression time" ,
12101266 *ClangOffloadBundlerTimerGroup);
12111267 if (Verbose)
12121268 DecompressTimer.startTimer ();
12131269
12141270 SmallVector<uint8_t , 0 > DecompressedData;
1215- StringRef CompressedData = Blob.substr (CurrentOffset );
1271+ StringRef CompressedData = Blob.substr (HeaderSize );
12161272 if (llvm::Error DecompressionError = llvm::compression::decompress (
12171273 CompressionFormat, llvm::arrayRefFromStringRef (CompressedData),
12181274 DecompressedData, UncompressedSize))
0 commit comments